Support categorical data for hist. (#7695)

* Extract partitioner from hist.
* Implement categorical data support by passing the gradient index directly into the partitioner.
* Organize/update document.
* Remove code for negative hessian.
Jiaming Yuan 2022-02-25 03:47:14 +08:00 committed by GitHub
parent f60d95b0ba
commit 83a66b4994
15 changed files with 402 additions and 498 deletions

@@ -244,9 +244,6 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method

- Use single precision to build histograms instead of double precision.

Additional parameters for ``approx`` and ``gpu_hist`` tree method
=================================================================

* ``max_cat_to_onehot``

  .. versionadded:: 1.6

@@ -256,8 +253,8 @@ Additional parameters for ``approx`` and ``gpu_hist`` tree method

  - A threshold for deciding whether XGBoost should use one-hot encoding based split for
    categorical data. When number of categories is lesser than the threshold then one-hot
    encoding is chosen, otherwise the categories will be partitioned into children nodes.
-   Only relevant for regression and binary classification. Also, `approx` or `gpu_hist`
-   tree method is required.
+   Only relevant for regression and binary classification. Also, ``exact`` tree method is
+   not supported.

Additional parameters for Dart Booster (``booster=dart``)
=========================================================

@@ -4,16 +4,16 @@ Categorical Data

.. note::

-  As of XGBoost 1.6, the feature is highly experimental and has limited features
+  As of XGBoost 1.6, the feature is experimental and has limited features

Starting from version 1.5, XGBoost has experimental support for categorical data available
-for public testing. At the moment, the support is implemented as one-hot encoding based
-categorical tree splits. For numerical data, the split condition is defined as
-:math:`value < threshold`, while for categorical data the split is defined as :math:`value
-== category` and ``category`` is a discrete value. More advanced categorical split
-strategy is planned for future releases and this tutorial details how to inform XGBoost
-about the data type. Also, the current support for training is limited to ``gpu_hist``
-tree method.
+for public testing. For numerical data, the split condition is defined as :math:`value <
+threshold`, while for categorical data the split is defined depending on whether
+partitioning or onehot encoding is used. For partition-based splits, the splits are
+specified as :math:`value \in categories`, where ``categories`` is the set of categories
+in one feature. If onehot encoding is used instead, then the split is defined as
+:math:`value == category`. More advanced categorical split strategy is planned for future
+releases and this tutorial details how to inform XGBoost about the data type.

************************************
Training with scikit-learn Interface
@@ -35,13 +35,13 @@ parameter ``enable_categorical``:

.. code:: python

-  # Only gpu_hist is supported for categorical data as mentioned previously
+  # Supported tree methods are `gpu_hist`, `approx`, and `hist`.
  clf = xgb.XGBClassifier(
      tree_method="gpu_hist", enable_categorical=True, use_label_encoder=False
  )
  # X is the dataframe we created in previous snippet
  clf.fit(X, y)
-  # Must use JSON for serialization, otherwise the information is lost
+  # Must use JSON/UBJSON for serialization, otherwise the information is lost.
  clf.save_model("categorical-model.json")
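For reference, a minimal sketch of the dataframe ``X`` assumed by the snippet above: the
categorical column must use pandas' dedicated ``category`` dtype for ``enable_categorical``
to pick it up (the column name and label values here are illustrative only).

.. code:: python

    import pandas as pd
    import xgboost as xgb

    # One categorical feature stored with the pandas "category" dtype.
    X = pd.DataFrame({"f0": pd.Series(["a", "b", "a", "c"], dtype="category")})
    y = [1, 0, 1, 0]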
@@ -60,11 +60,37 @@ can plot the model and calculate the global feature importance:

The ``scikit-learn`` interface from dask is similar to single node version. The basic
-idea is create dataframe with category feature type, and tell XGBoost to use ``gpu_hist``
-with parameter ``enable_categorical``. See :ref:`sphx_glr_python_examples_categorical.py`
-for a worked example of using categorical data with ``scikit-learn`` interface. A
-comparison between using one-hot encoded data and XGBoost's categorical data support can
-be found :ref:`sphx_glr_python_examples_cat_in_the_dat.py`.
+idea is create dataframe with category feature type, and tell XGBoost to use it by setting
+the ``enable_categorical`` parameter. See :ref:`sphx_glr_python_examples_categorical.py`
+for a worked example of using categorical data with ``scikit-learn`` interface with
+one-hot encoding. A comparison between using one-hot encoded data and XGBoost's
+categorical data support can be found :ref:`sphx_glr_python_examples_cat_in_the_dat.py`.
********************
Optimal Partitioning
********************
.. versionadded:: 1.6
Optimal partitioning is a technique for partitioning the categorical predictors for each
node split. The proof of optimality for numerical objectives like ``RMSE`` was first
introduced by `[1] <#references>`__. The algorithm was used in decision trees for handling
regression and binary classification tasks `[2] <#references>`__; later LightGBM `[3]
<#references>`__ brought it to the context of gradient boosting trees, and it is now also
adopted in XGBoost as an optional feature for handling categorical splits. More
specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to
partition a set of discrete values into groups based on the distances between a measure of
these values, one only needs to look at sorted partitions instead of enumerating all
possible permutations. In the context of decision trees, the discrete values are
categories, and the measure is the output leaf value. Intuitively, we want to group the
categories that output similar leaf values. During split finding, we first sort the
gradient histogram to prepare the contiguous partitions, then enumerate the splits
according to these sorted values. One of the related parameters for XGBoost is
``max_cat_to_onehot``, which controls whether one-hot encoding or partitioning should be
used for each feature, see :doc:`/parameter` for details. When the objective is not
regression or binary classification, XGBoost will fall back to using one-hot encoding
instead.
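To make the idea concrete, here is a hypothetical Python sketch of the sorted-partition
scan described above; it illustrates Fisher's result and is not XGBoost's implementation,
using an unregularized gain formula for brevity.

.. code:: python

    def best_partition(cat_stats):
        """cat_stats maps category -> (grad_sum, hess_sum) from the histogram."""
        # Sort categories by their gradient/hessian ratio; then only
        # prefix-vs-suffix partitions of this order need to be enumerated,
        # instead of all 2^k subsets.
        order = sorted(cat_stats, key=lambda c: cat_stats[c][0] / cat_stats[c][1])
        total_g = sum(g for g, _ in cat_stats.values())
        total_h = sum(h for _, h in cat_stats.values())
        best, best_gain = None, float("-inf")
        g_left = h_left = 0.0
        for i, cat in enumerate(order[:-1]):
            g, h = cat_stats[cat]
            g_left, h_left = g_left + g, h_left + h
            g_right, h_right = total_g - g_left, total_h - h_left
            gain = g_left**2 / h_left + g_right**2 / h_right - total_g**2 / total_h
            if gain > best_gain:
                best_gain, best = gain, set(order[: i + 1])
        return best, best_gain

    # Example: category "b" pulls predictions in the opposite direction, so it
    # ends up alone on one side of the split.
    print(best_partition({"a": (10.0, 2.0), "b": (-3.0, 1.0), "c": (9.0, 2.0)}))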
**********************
@@ -82,7 +108,7 @@ categorical data, we need to pass the similar parameter to :class:`DMatrix

  # X is a dataframe we created in previous snippet
  Xy = xgb.DMatrix(X, y, enable_categorical=True)
-  booster = xgb.train({"tree_method": "gpu_hist"}, Xy)
+  booster = xgb.train({"tree_method": "hist", "max_cat_to_onehot": 5}, Xy)
  # Must use JSON for serialization, otherwise the information is lost
  booster.save_model("categorical-model.json")

@@ -109,30 +135,7 @@ types by using the ``feature_types`` parameter in :class:`DMatrix <xgboost.DMatrix>`

For numerical data, the feature type can be ``"q"`` or ``"float"``, while for categorical
feature it's specified as ``"c"``. The Dask module in XGBoost has the same interface so
-:class:`dask.Array <dask.Array>` can also be used as categorical data.
+:class:`dask.Array <dask.Array>` can also be used for categorical data.
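A hedged sketch of the ``feature_types`` route (the array values and the 3-feature layout
are made up for illustration; categories are assumed to be encoded as non-negative
integers before construction):

.. code:: python

    import numpy as np
    import xgboost as xgb

    # Two numerical features ("q") and one categorical feature ("c").
    X = np.array([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0]])
    y = np.array([0.0, 1.0, 0.0])
    Xy = xgb.DMatrix(X, y, feature_types=["q", "q", "c"], enable_categorical=True)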
*************
Miscellaneous

@@ -604,6 +604,16 @@ class RegTree : public Model {
   */
  std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; }
  common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; }
/*!
* \brief Get the bit storage for categories
*/
common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
auto node_ptr = GetCategoriesMatrix().node_ptr;
auto categories = GetCategoriesMatrix().categories;
auto segment = node_ptr[nidx];
auto node_cats = categories.subspan(segment.beg, segment.size);
return node_cats;
}
  auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }

  // The fields of split_categories_segments_[i] are set such that
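As a side note, the categories returned by ``NodeCats`` are a bit field packed into 32-bit
words. A hedged Python sketch of the membership test this storage supports (the actual
decision logic lives in ``src/common/categorical.h``; this only illustrates the assumed
bit layout):

.. code:: python

    def category_bit_set(node_cats, category):
        """node_cats: list of 32-bit words, as returned by NodeCats(nidx)."""
        word, bit = divmod(int(category), 32)
        # Categories beyond the stored words were never set.
        if word >= len(node_cats):
            return False
        return (node_cats[word] >> bit) & 1 == 1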

@@ -582,10 +582,11 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes

        .. versionadded:: 1.3.0

        .. note:: This parameter is experimental

        Experimental support of specializing for categorical features. Do not set
-        to True unless you are interested in development. Currently it's only
-        available for `gpu_hist` and `approx` tree methods. Also, JSON/UBJSON
-        serialization format is required. (XGBoost 1.6 for approx)
+        to True unless you are interested in development. Also, JSON/UBJSON
+        serialization format is required.
        """
        if group is not None and qid is not None:

@@ -206,10 +206,11 @@ __model_doc = f'''

        .. versionadded:: 1.5.0

-        Experimental support for categorical data. Do not set to true unless you are
-        interested in development. Only valid when `gpu_hist` or `approx` is used along
-        with dataframe as input. Also, JSON/UBJSON serialization format is
-        required. (XGBoost 1.6 for approx)
+        .. note:: This parameter is experimental
+
+        Experimental support for categorical data. When enabled, cudf/pandas.DataFrame
+        should be used to specify categorical data type. Also, JSON/UBJSON
+        serialization format is required.

    max_cat_to_onehot : Optional[int]
@@ -220,9 +221,8 @@ __model_doc = f'''

        A threshold for deciding whether XGBoost should use one-hot encoding based split
        for categorical data. When number of categories is lesser than the threshold
        then one-hot encoding is chosen, otherwise the categories will be partitioned
-        into children nodes. Only relevant for regression and binary
-        classification. Also, ``approx`` or ``gpu_hist`` tree method is required. See
-        :doc:`Categorical Data </tutorials/categorical>` for details.
+        into children nodes. Only relevant for regression and binary classification.
+        See :doc:`Categorical Data </tutorials/categorical>` for details.

    eval_metric : Optional[Union[str, List[str], Callable]]
@@ -846,7 +846,8 @@ class XGBModel(XGBModelBase):
        callbacks = self.callbacks if self.callbacks is not None else callbacks

        tree_method = params.get("tree_method", None)
-        if self.enable_categorical and tree_method not in ("gpu_hist", "approx"):
+        cat_support = {"gpu_hist", "approx", "hist"}
+        if self.enable_categorical and tree_method not in cat_support:
            raise ValueError(
                "Experimental support for categorical data is not implemented for"
                " current tree method yet."

@@ -1,5 +1,5 @@
/*!
- * Copyright 2021 by Contributors
+ * Copyright 2021-2022 by Contributors
 * \file row_set.h
 * \brief Quick Utility to compute subset of rows
 * \author Philip Cho, Tianqi Chen
@@ -8,12 +8,15 @@
#define XGBOOST_COMMON_PARTITION_BUILDER_H_

#include <xgboost/data.h>

#include <algorithm>
-#include <vector>
-#include <utility>
#include <memory>
+#include <utility>
+#include <vector>

+#include "categorical.h"
+#include "column_matrix.h"
#include "xgboost/tree_model.h"
-#include "../common/column_matrix.h"

namespace xgboost {
namespace common {
@@ -46,18 +49,24 @@ class PartitionBuilder {
  // on comparison of indexes values (idx_span) and split point (split_cond)
  // Handle dense columns
  // Analog of std::stable_partition, but in no-inplace manner
-  template <bool default_left, bool any_missing, typename ColumnType>
-  inline std::pair<size_t, size_t> PartitionKernel(const ColumnType& column,
-      common::Span<const size_t> rid_span, const int32_t split_cond,
-      common::Span<size_t> left_part, common::Span<size_t> right_part) {
+  template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
+  inline std::pair<size_t, size_t> PartitionKernel(const ColumnType& column,
+                                                   common::Span<const size_t> row_indices,
+                                                   common::Span<size_t> left_part,
+                                                   common::Span<size_t> right_part,
+                                                   size_t base_rowid, Predicate&& pred) {
    size_t* p_left_part = left_part.data();
    size_t* p_right_part = right_part.data();
    size_t nleft_elems = 0;
    size_t nright_elems = 0;
-    auto state = column.GetInitialState(rid_span.front());
-    for (auto rid : rid_span) {
-      const int32_t bin_id = column.GetBinIdx(rid, &state);
+    auto state = column.GetInitialState(row_indices.front() - base_rowid);
+    auto p_row_indices = row_indices.data();
+    auto n_samples = row_indices.size();
+    for (size_t i = 0; i < n_samples; ++i) {
+      auto rid = p_row_indices[i];
+      const int32_t bin_id = column.GetBinIdx(rid - base_rowid, &state);
      if (any_missing && bin_id == ColumnType::kMissingId) {
        if (default_left) {
          p_left_part[nleft_elems++] = rid;
@@ -65,7 +74,7 @@ class PartitionBuilder {
          p_right_part[nright_elems++] = rid;
        }
      } else {
-        if (bin_id <= split_cond) {
+        if (pred(rid, bin_id)) {
          p_left_part[nleft_elems++] = rid;
        } else {
          p_right_part[nright_elems++] = rid;
@@ -95,41 +104,66 @@ class PartitionBuilder {
    return {nleft_elems, nright_elems};
  }
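For orientation, here is a hedged Python sketch of what this kernel computes, ignoring
missing values and threading: a stable, out-of-place partition of row indices driven by an
arbitrary predicate, which is what lets the same kernel serve both numerical
(``bin_id <= split_cond``) and categorical splits.

.. code:: python

    def partition_kernel(row_indices, pred):
        """Stable partition: rows satisfying pred go left, the rest go right."""
        left, right = [], []
        for rid in row_indices:
            (left if pred(rid) else right).append(rid)
        return left, right

    # Numerical split: route by comparing the row's bin index to the split bin.
    bins, split_cond = [0, 2, 1, 3], 1
    print(partition_kernel(range(4), lambda rid: bins[rid] <= split_cond))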
-  template <typename BinIdxType, bool any_missing>
+  template <typename BinIdxType, bool any_missing, bool any_cat>
  void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
-                 const int32_t split_cond,
+                 const int32_t split_cond, GHistIndexMatrix const& gmat,
                 const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
    common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
-    common::Span<size_t> left = GetLeftBuffer(node_in_set,
-                                              range.begin(), range.end());
-    common::Span<size_t> right = GetRightBuffer(node_in_set,
-                                                range.begin(), range.end());
+    common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
+    common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
    const bst_uint fid = tree[nid].SplitIndex();
    const bool default_left = tree[nid].DefaultLeft();
    const auto column_ptr = column_matrix.GetColumn<BinIdxType, any_missing>(fid);
-    std::pair<size_t, size_t> child_nodes_sizes;
+    bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
auto node_cats = tree.NodeCats(nid);
auto const& index = gmat.index;
auto const& cut_values = gmat.cut.Values();
auto const& cut_ptrs = gmat.cut.Ptrs();
auto pred = [&](auto ridx, auto bin_id) {
if (any_cat && is_cat) {
auto begin = gmat.RowIdx(ridx);
auto end = gmat.RowIdx(ridx + 1);
auto f_begin = cut_ptrs[fid];
auto f_end = cut_ptrs[fid + 1];
// bypassing the column matrix as we need the cut value instead of bin idx for categorical
// features.
auto gidx = BinarySearchBin(begin, end, index, f_begin, f_end);
bool go_left;
if (gidx == -1) {
go_left = default_left;
} else {
go_left = Decision(node_cats, cut_values[gidx], default_left);
}
return go_left;
} else {
return bin_id <= split_cond;
}
};
std::pair<size_t, size_t> child_nodes_sizes;
    if (column_ptr->GetType() == xgboost::common::kDenseColumn) {
      const common::DenseColumn<BinIdxType, any_missing>& column =
          static_cast<const common::DenseColumn<BinIdxType, any_missing>& >(*(column_ptr.get()));
      if (default_left) {
-        child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span,
-                                                               split_cond, left, right);
+        child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
+                                                               gmat.base_rowid, pred);
      } else {
-        child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span,
-                                                                split_cond, left, right);
+        child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
+                                                                gmat.base_rowid, pred);
      }
    } else {
      CHECK_EQ(any_missing, true);
      const common::SparseColumn<BinIdxType>& column
          = static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
      if (default_left) {
-        child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span,
-                                                               split_cond, left, right);
+        child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
+                                                               gmat.base_rowid, pred);
      } else {
-        child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span,
-                                                                split_cond, left, right);
+        child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
+                                                                gmat.base_rowid, pred);
      }
    }
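The categorical branch of the predicate above bypasses the column matrix and looks the row
up in the gradient index directly. A hedged Python sketch of that lookup, assuming each
row's bin ids are sorted and the split feature owns the contiguous id range
``[f_begin, f_end)``:

.. code:: python

    import bisect

    def binary_search_bin(row_bins, f_begin, f_end):
        """Find this row's bin id for one feature, or -1 if the value is missing."""
        i = bisect.bisect_left(row_bins, f_begin)
        if i < len(row_bins) and f_begin <= row_bins[i] < f_end:
            return row_bins[i]
        return -1

    # Row with entries for features owning bins [0,4), [4,8), [8,12); feature 1 present.
    print(binary_search_bin([2, 5, 9], 4, 8))  # -> 5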

@@ -275,9 +275,6 @@ class MemStackAllocator {
  T& operator[](size_t i) { return ptr_[i]; }
  T const& operator[](size_t i) const { return ptr_[i]; }

-  // FIXME(jiamingy): Remove this once we merge partitioner cleanup for hist.
-  auto Get() { return ptr_; }

 private:
  T* ptr_ = nullptr;
  size_t required_size_;

@@ -288,10 +288,10 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
    auto base_weight =
        evaluator.CalcWeight(candidate.nid, param_, GradStats{parent_sum});
-    auto left_weight = evaluator.CalcWeight(
-        candidate.nid, param_, GradStats{candidate.split.left_sum});
-    auto right_weight = evaluator.CalcWeight(
-        candidate.nid, param_, GradStats{candidate.split.right_sum});
+    auto left_weight =
+        evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.left_sum});
+    auto right_weight =
+        evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.right_sum});

    if (candidate.split.is_cat) {
      std::vector<uint32_t> split_cats;
@@ -308,11 +308,11 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
        split_cats = candidate.split.cat_bits;
        common::CatBitField cat_bits{split_cats};
      }

      tree.ExpandCategorical(
          candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(),
-          base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
-          candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
+          base_weight, left_weight * param_.learning_rate, right_weight * param_.learning_rate,
+          candidate.split.loss_chg, parent_sum.GetHess(), candidate.split.left_sum.GetHess(),
+          candidate.split.right_sum.GetHess());
    } else {
      tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
                      candidate.split.DefaultLeft(), base_weight,

@@ -124,11 +124,12 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
  nodes_for_subtraction_trick_.clear();
  nodes_for_explicit_hist_build_.push_back(node);

  auto const& row_set_collection = partitioner_.front().Partitions();

  size_t page_id = 0;
  for (auto const& gidx :
       p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
    this->histogram_builder_->BuildHist(
-        page_id, gidx, p_tree, row_set_collection_,
+        page_id, gidx, p_tree, row_set_collection,
        nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h);
    ++page_id;
  }
@@ -149,7 +150,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
        grad_stat.Add(et.GetGrad(), et.GetHess());
      }
    } else {
-      const common::RowSetCollection::Elem e = row_set_collection_[nid];
+      const common::RowSetCollection::Elem e = row_set_collection[nid];
      for (const size_t *it = e.begin; it < e.end; ++it) {
        grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
      }
@@ -204,6 +205,7 @@ void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
    const std::vector<CPUExpandEntry> &nodes_for_apply_split,
    std::vector<CPUExpandEntry> *nodes_to_evaluate, RegTree *p_tree) {
  builder_monitor_.Start("SplitSiblings");
  auto const& row_set_collection = this->partitioner_.front().Partitions();
  for (auto const& entry : nodes_for_apply_split) {
    int nid = entry.nid;
@@ -213,7 +215,7 @@ void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
    const CPUExpandEntry right_node = CPUExpandEntry(cright, p_tree->GetDepth(cright), 0.0);
    nodes_to_evaluate->push_back(left_node);
    nodes_to_evaluate->push_back(right_node);
-    if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
+    if (row_set_collection[cleft].Size() < row_set_collection[cright].Size()) {
      nodes_for_explicit_hist_build_.push_back(left_node);
      nodes_for_subtraction_trick_.push_back(right_node);
    } else {
@@ -253,16 +255,23 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
    AddSplitsToTree(expand, p_tree, &num_leaves, &nodes_for_apply_split);

    if (nodes_for_apply_split.size() != 0) {
-      ApplySplit<any_missing>(nodes_for_apply_split, gmat, column_matrix, p_tree);
+      HistRowPartitioner &partitioner = this->partitioner_.front();
+      if (gmat.cut.HasCategorical()) {
+        partitioner.UpdatePosition<any_missing, true>(this->ctx_, gmat, column_matrix,
+                                                      nodes_for_apply_split, p_tree);
+      } else {
+        partitioner.UpdatePosition<any_missing, false>(this->ctx_, gmat, column_matrix,
+                                                       nodes_for_apply_split, p_tree);
+      }
      SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree);

      if (param_.max_depth == 0 || depth < param_.max_depth) {
        size_t i = 0;
-        for (auto const& gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
-          this->histogram_builder_->BuildHist(
-              i, gidx, p_tree, row_set_collection_,
-              nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_,
-              gpair_h);
+        for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
+          this->histogram_builder_->BuildHist(i, gidx, p_tree, partitioner_.front().Partitions(),
+                                              nodes_for_explicit_hist_build_,
+                                              nodes_for_subtraction_trick_, gpair_h);
          ++i;
        }
      } else {
@@ -293,7 +302,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::Update(
    const GHistIndexMatrix &gmat,
-    const ColumnMatrix &column_matrix,
+    const common::ColumnMatrix &column_matrix,
    HostDeviceVector<GradientPair> *gpair,
    DMatrix *p_fmat, RegTree *p_tree) {
  builder_monitor_.Start("Update");
@@ -333,14 +342,14 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
  CHECK_GT(out_preds.Size(), 0U);

-  size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin();
-
-  common::BlockedSpace2d space(n_nodes, [&](size_t node) {
-    return row_set_collection_[node].Size();
-  }, 1024);
+  CHECK_EQ(partitioner_.size(), 1);
+  auto const &row_set_collection = this->partitioner_.front().Partitions();
+  size_t n_nodes = row_set_collection.end() - row_set_collection.begin();
+  common::BlockedSpace2d space(
+      n_nodes, [&](size_t node) { return partitioner_.front()[node].Size(); }, 1024);
  CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
  common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node, common::Range1d r) {
-    const RowSetCollection::Elem rowset = row_set_collection_[node];
+    const common::RowSetCollection::Elem rowset = row_set_collection[node];
    if (rowset.begin != nullptr && rowset.end != nullptr) {
      int nid = rowset.node_id;
      bst_float leaf_value;
@@ -354,7 +363,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
      }
      leaf_value = (*p_last_tree_)[nid].LeafValue();

-      for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
+      for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
        out_preds(*it) += leaf_value;
      }
    }
@@ -364,10 +373,9 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
  return true;
}
-template<typename GradientSumT>
+template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix& fmat,
-                                                            std::vector<GradientPair>* gpair,
-                                                            std::vector<size_t>* row_indices) {
+                                                            std::vector<GradientPair>* gpair) {
  const auto& info = fmat.Info();
  auto& rnd = common::GlobalRandom();
  std::vector<GradientPair>& gpair_ref = *gpair;
@@ -410,101 +418,31 @@ template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitData(
    const GHistIndexMatrix &gmat, const DMatrix &fmat, const RegTree &tree,
    std::vector<GradientPair> *gpair) {
-  CHECK((param_.max_depth > 0 || param_.max_leaves > 0))
-      << "max_depth or max_leaves cannot be both 0 (unlimited); "
-      << "at least one should be a positive quantity.";
-  if (param_.grow_policy == TrainParam::kDepthWise) {
-    CHECK(param_.max_depth > 0) << "max_depth cannot be 0 (unlimited) "
-                                << "when grow_policy is depthwise.";
-  }
  builder_monitor_.Start("InitData");
  const auto& info = fmat.Info();

  {
-    // initialize the row set
-    row_set_collection_.Clear();
    // initialize histogram collection
    uint32_t nbins = gmat.cut.Ptrs().back();
    // initialize histogram builder
    dmlc::OMPException exc;
    exc.Rethrow();
    this->histogram_builder_->Reset(nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin},
                                    this->ctx_->Threads(), 1, rabit::IsDistributed());
-    std::vector<size_t>& row_indices = *row_set_collection_.Data();
-    row_indices.resize(info.num_row_);
-    size_t* p_row_indices = row_indices.data();
-    // mark subsample and build list of member rows
    if (param_.subsample < 1.0f) {
      CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
          << "Only uniform sampling is supported, "
          << "gradient-based sampling is only support by GPU Hist.";
      builder_monitor_.Start("InitSampling");
-      InitSampling(fmat, gpair, &row_indices);
+      InitSampling(fmat, gpair);
      builder_monitor_.Stop("InitSampling");
-      CHECK_EQ(row_indices.size(), info.num_row_);
      // We should check that the partitioning was done correctly
      // and each row of the dataset fell into exactly one of the categories
    }
auto n_threads = this->ctx_->Threads();
common::MemStackAllocator<bool, 128> buff(n_threads);
bool* p_buff = buff.Get();
std::fill(p_buff, p_buff + this->ctx_->Threads(), false);
const size_t block_size = info.num_row_ / n_threads + !!(info.num_row_ % n_threads);
#pragma omp parallel num_threads(n_threads)
{
exc.Run([&]() {
const size_t tid = omp_get_thread_num();
const size_t ibegin = tid * block_size;
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
static_cast<size_t>(info.num_row_));
for (size_t i = ibegin; i < iend; ++i) {
if ((*gpair)[i].GetHess() < 0.0f) {
p_buff[tid] = true;
break;
}
}
});
}
exc.Rethrow();
bool has_neg_hess = false;
for (int32_t tid = 0; tid < n_threads; ++tid) {
if (p_buff[tid]) {
has_neg_hess = true;
}
}
if (has_neg_hess) {
size_t j = 0;
for (size_t i = 0; i < info.num_row_; ++i) {
if ((*gpair)[i].GetHess() >= 0.0f) {
p_row_indices[j++] = i;
}
}
row_indices.resize(j);
} else {
#pragma omp parallel num_threads(n_threads)
{
exc.Run([&]() {
const size_t tid = omp_get_thread_num();
const size_t ibegin = tid * block_size;
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
static_cast<size_t>(info.num_row_));
for (size_t i = ibegin; i < iend; ++i) {
p_row_indices[i] = i;
}
});
}
exc.Rethrow();
}
  }
-  row_set_collection_.Init();
+  partitioner_.clear();
  partitioner_.emplace_back(info.num_row_, 0, this->ctx_->Threads());

  {
    /* determine layout of data */
@@ -558,12 +496,9 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
  builder_monitor_.Stop("InitData");
}
-template <typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
-    const std::vector<CPUExpandEntry>& nodes,
-    const RegTree& tree,
-    const GHistIndexMatrix& gmat,
-    std::vector<int32_t>* split_conditions) {
+void HistRowPartitioner::FindSplitConditions(const std::vector<CPUExpandEntry> &nodes,
+                                             const RegTree &tree, const GHistIndexMatrix &gmat,
+                                             std::vector<int32_t> *split_conditions) {
  const size_t n_nodes = nodes.size();
  split_conditions->resize(n_nodes);
@@ -576,8 +511,7 @@ void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
    int32_t split_cond = -1;
    // convert floating-point split_pt into corresponding bin_id
    // split_cond = -1 indicates that split_pt is less than all known cut points
-    CHECK_LT(upper_bound,
-             static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+    CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
    for (uint32_t bound = lower_bound; bound < upper_bound; ++bound) {
      if (split_pt == gmat.cut.Values()[bound]) {
        split_cond = static_cast<int32_t>(bound);
@@ -586,88 +520,20 @@ void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
    (*split_conditions)[i] = split_cond;
  }
}
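A hedged Python rendering of the bin lookup in FindSplitConditions, assuming
``cut_values`` holds the global sorted cut points and ``[lower, upper)`` is the bin range
owned by the split feature:

.. code:: python

    def find_split_condition(split_pt, cut_values, lower, upper):
        """Map a float threshold back to its histogram bin index."""
        for bound in range(lower, upper):
            if cut_values[bound] == split_pt:
                return bound
        return -1  # split_pt is below every known cut point

    print(find_split_condition(0.5, [0.1, 0.5, 0.9], 0, 3))  # -> 1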
-template <typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToRowSet(
-    const std::vector<CPUExpandEntry>& nodes,
-    RegTree* p_tree) {
+void HistRowPartitioner::AddSplitsToRowSet(const std::vector<CPUExpandEntry> &nodes,
+                                           RegTree const *p_tree) {
  const size_t n_nodes = nodes.size();
  for (unsigned int i = 0; i < n_nodes; ++i) {
    const int32_t nid = nodes[i].nid;
    const size_t n_left = partition_builder_.GetNLeftElems(i);
    const size_t n_right = partition_builder_.GetNRightElems(i);
    CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
-    row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(),
-                                 (*p_tree)[nid].RightChild(), n_left, n_right);
+    row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(),
+                                 n_left, n_right);
  }
}
template <typename GradientSumT>
template <bool any_missing>
void QuantileHistMaker::Builder<GradientSumT>::ApplySplit(const std::vector<CPUExpandEntry> nodes,
const GHistIndexMatrix& gmat,
const ColumnMatrix& column_matrix,
RegTree* p_tree) {
builder_monitor_.Start("ApplySplit");
// 1. Find split condition for each split
const size_t n_nodes = nodes.size();
std::vector<int32_t> split_conditions;
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
// 2.1 Create a blocked space of size SUM(samples in each node)
common::BlockedSpace2d space(n_nodes, [&](size_t node_in_set) {
int32_t nid = nodes[node_in_set].nid;
return row_set_collection_[nid].Size();
}, kPartitionBlockSize);
// 2.2 Initialize the partition builder
// allocate buffers for storage intermediate results by each thread
partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
const int32_t nid = nodes[node_in_set].nid;
const size_t size = row_set_collection_[nid].Size();
const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
return n_tasks;
});
// 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
// Store results in intermediate buffers from partition_builder_
common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node_in_set, common::Range1d r) {
size_t begin = r.begin();
const int32_t nid = nodes[node_in_set].nid;
const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
partition_builder_.AllocateForTask(task_id);
switch (column_matrix.GetTypeSize()) {
case common::kUint8BinsTypeSize:
partition_builder_.template Partition<uint8_t, any_missing>(node_in_set, nid, r,
split_conditions[node_in_set], column_matrix,
*p_tree, row_set_collection_[nid].begin);
break;
case common::kUint16BinsTypeSize:
partition_builder_.template Partition<uint16_t, any_missing>(node_in_set, nid, r,
split_conditions[node_in_set], column_matrix,
*p_tree, row_set_collection_[nid].begin);
break;
case common::kUint32BinsTypeSize:
partition_builder_.template Partition<uint32_t, any_missing>(node_in_set, nid, r,
split_conditions[node_in_set], column_matrix,
*p_tree, row_set_collection_[nid].begin);
break;
default:
CHECK(false); // no default behavior
}
});
// 3. Compute offsets to copy blocks of row-indexes
// from partition_builder_ to row_set_collection_
partition_builder_.CalculateRowOffsets();
// 4. Copy elements from partition_builder_ to row_set_collection_ back
// with updated row-indexes for each tree-node
common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node_in_set, common::Range1d r) {
const int32_t nid = nodes[node_in_set].nid;
partition_builder_.MergeToArray(node_in_set, r.begin(),
const_cast<size_t*>(row_set_collection_[nid].begin));
});
// 5. Add info about splits into row_set_collection_
AddSplitsToRowSet(nodes, p_tree);
builder_monitor_.Stop("ApplySplit");
}
template struct QuantileHistMaker::Builder<float>;
template struct QuantileHistMaker::Builder<double>;

@@ -11,9 +11,9 @@
#include <rabit/rabit.h>
#include <xgboost/tree_updater.h>

-#include <iomanip>
-#include <limits>
+#include <algorithm>
#include <memory>
-#include <queue>
#include <string>
#include <utility>
#include <vector>
@@ -38,8 +38,6 @@
#include "../common/column_matrix.h"

namespace xgboost {

struct RandomReplace {
 public:
  // similar value as for minstd_rand
@@ -82,15 +80,127 @@
};

namespace tree {
class HistRowPartitioner {
// heuristically chosen block size of parallel partitioning
static constexpr size_t kPartitionBlockSize = 2048;
  // worker class that partitions a block of rows
common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
// storage for row index
common::RowSetCollection row_set_collection_;
-using xgboost::GHistIndexMatrix;
-using xgboost::common::GHistIndexRow;
-using xgboost::common::HistCollection;
-using xgboost::common::RowSetCollection;
-using xgboost::common::GHistRow;
-using xgboost::common::GHistBuilder;
-using xgboost::common::ColumnMatrix;
-using xgboost::common::Column;
+  /**
+   * \brief Turn split values into discrete bin indices.
+   */
+  static void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
+                                  const GHistIndexMatrix& gmat,
+                                  std::vector<int32_t>* split_conditions);
+  /**
+   * \brief Update the row set for new splits specified by nodes.
+   */
+  void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree const* p_tree);
public:
bst_row_t base_rowid = 0;
public:
HistRowPartitioner(size_t n_samples, size_t base_rowid, int32_t n_threads) {
row_set_collection_.Clear();
const size_t block_size = n_samples / n_threads + !!(n_samples % n_threads);
dmlc::OMPException exc;
std::vector<size_t>& row_indices = *row_set_collection_.Data();
row_indices.resize(n_samples);
size_t* p_row_indices = row_indices.data();
    // parallel initialization of row indices (std::iota)
#pragma omp parallel num_threads(n_threads)
{
exc.Run([&]() {
const size_t tid = omp_get_thread_num();
const size_t ibegin = tid * block_size;
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size), n_samples);
for (size_t i = ibegin; i < iend; ++i) {
p_row_indices[i] = i + base_rowid;
}
});
}
row_set_collection_.Init();
this->base_rowid = base_rowid;
}
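As a sketch of what this constructor computes (hedged, single-threaded Python standing in
for the OMP loop): a chunked ``std::iota`` that fills contiguous blocks of row indices,
offset by ``base_rowid``, one block per thread.

.. code:: python

    def init_row_indices(n_samples, base_rowid, n_threads):
        block = n_samples // n_threads + (1 if n_samples % n_threads else 0)
        out = [0] * n_samples
        for tid in range(n_threads):  # each iteration models one OMP thread
            for i in range(tid * block, min((tid + 1) * block, n_samples)):
                out[i] = i + base_rowid
        return out

    print(init_row_indices(5, 100, 2))  # -> [100, 101, 102, 103, 104]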
template <bool any_missing, bool any_cat>
void UpdatePosition(GenericParameter const* ctx, GHistIndexMatrix const& gmat,
common::ColumnMatrix const& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
// 1. Find split condition for each split
const size_t n_nodes = nodes.size();
std::vector<int32_t> split_conditions;
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
// 2.1 Create a blocked space of size SUM(samples in each node)
common::BlockedSpace2d space(
n_nodes,
[&](size_t node_in_set) {
int32_t nid = nodes[node_in_set].nid;
return row_set_collection_[nid].Size();
},
kPartitionBlockSize);
// 2.2 Initialize the partition builder
    // allocate buffers for storing intermediate results by each thread
partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
const int32_t nid = nodes[node_in_set].nid;
const size_t size = row_set_collection_[nid].Size();
const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
return n_tasks;
});
CHECK_EQ(base_rowid, gmat.base_rowid);
// 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
// Store results in intermediate buffers from partition_builder_
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
size_t begin = r.begin();
const int32_t nid = nodes[node_in_set].nid;
const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
partition_builder_.AllocateForTask(task_id);
switch (column_matrix.GetTypeSize()) {
case common::kUint8BinsTypeSize:
partition_builder_.template Partition<uint8_t, any_missing, any_cat>(
node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
row_set_collection_[nid].begin);
break;
case common::kUint16BinsTypeSize:
partition_builder_.template Partition<uint16_t, any_missing, any_cat>(
node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
row_set_collection_[nid].begin);
break;
case common::kUint32BinsTypeSize:
partition_builder_.template Partition<uint32_t, any_missing, any_cat>(
node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
row_set_collection_[nid].begin);
break;
default:
// no default behavior
CHECK(false) << column_matrix.GetTypeSize();
}
});
// 3. Compute offsets to copy blocks of row-indexes
// from partition_builder_ to row_set_collection_
partition_builder_.CalculateRowOffsets();
// 4. Copy elements from partition_builder_ to row_set_collection_ back
// with updated row-indexes for each tree-node
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
const int32_t nid = nodes[node_in_set].nid;
partition_builder_.MergeToArray(node_in_set, r.begin(),
const_cast<size_t*>(row_set_collection_[nid].begin));
});
// 5. Add info about splits into row_set_collection_
AddSplitsToRowSet(nodes, p_tree);
}
auto const& Partitions() const { return row_set_collection_; }
size_t Size() const {
return std::distance(row_set_collection_.begin(), row_set_collection_.end());
}
auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
};
inline BatchParam HistBatch(TrainParam const& param) {
  return {param.max_bin, param.sparse_threshold};
@@ -185,21 +295,7 @@ class QuantileHistMaker: public TreeUpdater {
    size_t GetNumberOfTrees();

-    void InitSampling(const DMatrix& fmat,
-                      std::vector<GradientPair>* gpair,
-                      std::vector<size_t>* row_indices);
+    void InitSampling(const DMatrix& fmat, std::vector<GradientPair>* gpair);

-    template <bool any_missing>
-    void ApplySplit(std::vector<CPUExpandEntry> nodes,
-                    const GHistIndexMatrix& gmat,
-                    const ColumnMatrix& column_matrix,
-                    RegTree* p_tree);
-
-    void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree* p_tree);
-
-    void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
-                             const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions);
    template <bool any_missing>
    void InitRoot(DMatrix* p_fmat,
@@ -221,7 +317,7 @@ class QuantileHistMaker: public TreeUpdater {
    template <bool any_missing>
    void ExpandTree(const GHistIndexMatrix& gmat,
-                    const ColumnMatrix& column_matrix,
+                    const common::ColumnMatrix& column_matrix,
                    DMatrix* p_fmat,
                    RegTree* p_tree,
                    const std::vector<GradientPair>& gpair_h);
@@ -232,9 +328,6 @@ class QuantileHistMaker: public TreeUpdater {
    std::shared_ptr<common::ColumnSampler> column_sampler_{
        std::make_shared<common::ColumnSampler>()};

-    std::vector<size_t> unused_rows_;
-    // the internal row sets
-    RowSetCollection row_set_collection_;
    std::vector<GradientPair> gpair_local_;

    /*! \brief feature with least # of bins. to be used for dense specialization
@@ -243,12 +336,12 @@ class QuantileHistMaker: public TreeUpdater {
    std::unique_ptr<TreeUpdater> pruner_;
    std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_;

-    static constexpr size_t kPartitionBlockSize = 2048;
-    common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
+    // Right now there's only 1 partitioner in this vector, when external memory is fully
+    // supported we will have number of partitioners equal to number of pages.
+    std::vector<HistRowPartitioner> partitioner_;

    // back pointers to tree and data matrix
-    const RegTree* p_last_tree_;
+    const RegTree* p_last_tree_{nullptr};
    DMatrix const* const p_last_fmat_;
    DMatrix* p_last_fmat_mutable_;

@@ -40,7 +40,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
  std::iota(row_indices.begin(), row_indices.end(), 0);
  row_set_collection.Init();

-  auto hist_builder = GHistBuilder<GradientSumT>(gmat.cut.Ptrs().back());
+  auto hist_builder = common::GHistBuilder<GradientSumT>(gmat.cut.Ptrs().back());
  hist.Init(gmat.cut.Ptrs().back());
  hist.AddHistRow(0);
  hist.AllocateAllData();
@@ -94,7 +94,7 @@ TEST(HistEvaluator, Apply) {
  RegTree tree;
  int static constexpr kNRows = 8, kNCols = 16;
  TrainParam param;
-  param.UpdateAllowUnknown(Args{{}});
+  param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}});
  auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
  auto sampler = std::make_shared<common::ColumnSampler>();
  auto evaluator_ = HistEvaluator<float, CPUExpandEntry>{param, dmat->Info(), 4, sampler,
@@ -102,12 +102,22 @@ TEST(HistEvaluator, Apply) {
  CPUExpandEntry entry{0, 0, 10.0f};
  entry.split.left_sum = GradStats{0.4, 0.6f};
-  entry.split.right_sum = GradStats{0.5, 0.7f};
+  entry.split.right_sum = GradStats{0.5, 0.5f};

  evaluator_.ApplyTreeSplit(entry, &tree);
  ASSERT_EQ(tree.NumExtraNodes(), 2);
  ASSERT_EQ(tree.Stat(tree[0].LeftChild()).sum_hess, 0.6f);
-  ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.7f);
+  ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.5f);
{
RegTree tree;
entry.split.is_cat = true;
entry.split.split_value = 1.0;
evaluator_.ApplyTreeSplit(entry, &tree);
auto l = entry.split.left_sum;
ASSERT_NEAR(tree[1].LeafValue(), -l.sum_grad / l.sum_hess * param.learning_rate, kRtEps);
ASSERT_NEAR(tree[2].LeafValue(), -param.learning_rate, kRtEps);
}
}
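The leaf values asserted in the block above follow the usual Newton step scaled by the
learning rate; a hedged sketch without the regularization terms:

.. code:: python

    def leaf_value(grad_sum, hess_sum, learning_rate):
        return -grad_sum / hess_sum * learning_rate

    # left child: GradStats{0.4, 0.6} -> -0.4 / 0.6 * eta
    # right child: GradStats{0.5, 0.5} -> -1.0 * eta, i.e. -eta
    print(leaf_value(0.5, 0.5, 0.3))  # -> -0.3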
TEST_F(TestPartitionBasedSplit, CPUHist) {

@@ -1,26 +1,14 @@
/*!
- * Copyright 2021 XGBoost contributors
+ * Copyright 2021-2022, XGBoost contributors.
 */
#include <gtest/gtest.h>

#include "../../../src/tree/updater_approx.h"
#include "../helpers.h"
#include "test_partitioner.h"

namespace xgboost {
namespace tree {
-namespace {
-void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) {
-  tree->ExpandNode(
-      /*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
-      /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
-      /*left_sum=*/0.0f,
-      /*right_sum=*/0.0f);
-  candidates->front().split.split_value = split_value;
-  candidates->front().split.sindex = 0;
-  candidates->front().split.sindex |= (1U << 31);
-}
-}  // anonymous namespace
TEST(Approx, Partitioner) {
  size_t n_samples = 1024, n_features = 1, base_rowid = 0;
  ApproxRowPartitioner partitioner{n_samples, base_rowid};

@@ -0,0 +1,21 @@
/*!
* Copyright 2021-2022, XGBoost contributors.
*/
#include <xgboost/tree_model.h>
#include <vector>
#include "../../../src/tree/hist/expand_entry.h"
namespace xgboost {
namespace tree {
inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) {
tree->ExpandNode(
/*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f,
/*right_sum=*/0.0f);
candidates->front().split.split_value = split_value;
candidates->front().split.sindex = 0;
candidates->front().split.sindex |= (1U << 31);
}
} // namespace tree
} // namespace xgboost

@@ -1,18 +1,19 @@
/*!
 * Copyright 2018-2022 by XGBoost Contributors
 */
+#include <gtest/gtest.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/tree_updater.h>
-#include <gtest/gtest.h>

#include <algorithm>
-#include <vector>
#include <string>
+#include <vector>

-#include "../helpers.h"
#include "../../../src/tree/param.h"
-#include "../../../src/tree/updater_quantile_hist.h"
#include "../../../src/tree/split_evaluator.h"
+#include "../../../src/tree/updater_quantile_hist.h"
+#include "../helpers.h"
+#include "test_partitioner.h"
#include "xgboost/data.h"

namespace xgboost {
@@ -94,130 +95,6 @@ class QuantileHistMock : public QuantileHistMaker {
      }
    }
  }
void TestInitDataSampling(const GHistIndexMatrix& gmat,
std::vector<GradientPair>* gpair,
DMatrix* p_fmat,
const RegTree& tree) {
// check SimpleSkip
size_t initial_seed = 777;
std::linear_congruential_engine<std::uint_fast64_t, 16807, 0,
static_cast<uint64_t>(1) << 63 > eng_first(initial_seed);
for (size_t i = 0; i < 100; ++i) {
eng_first();
}
uint64_t initial_seed_th = RandomReplace::SimpleSkip(100, initial_seed, 16807, RandomReplace::kMod);
std::linear_congruential_engine<std::uint_fast64_t, RandomReplace::kBase, 0,
RandomReplace::kMod > eng_second(initial_seed_th);
ASSERT_EQ(eng_first(), eng_second());
const size_t nthreads = omp_get_num_threads();
// save state of global rng engine
auto initial_rnd = common::GlobalRandom();
std::vector<size_t> unused_rows_cpy = this->unused_rows_;
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data());
std::vector<size_t> unused_row_indices_initial = this->unused_rows_;
ASSERT_EQ(row_indices_initial.size(), p_fmat->Info().num_row_);
auto check_each_row_occurs_in_one_of_arrays = [](const std::vector<size_t>& first,
const std::vector<size_t>& second,
size_t nrows) {
ASSERT_EQ(first.size(), nrows);
ASSERT_EQ(second.size(), 0);
};
check_each_row_occurs_in_one_of_arrays(row_indices_initial, unused_row_indices_initial,
p_fmat->Info().num_row_);
for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) {
omp_set_num_threads(i_nthreads);
// return initial state of global rng engine
common::GlobalRandom() = initial_rnd;
this->unused_rows_ = unused_rows_cpy;
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
std::vector<size_t>& row_indices = *(this->row_set_collection_.Data());
ASSERT_EQ(row_indices_initial.size(), row_indices.size());
for (size_t i = 0; i < row_indices_initial.size(); ++i) {
ASSERT_EQ(row_indices_initial[i], row_indices[i]);
}
std::vector<size_t>& unused_row_indices = this->unused_rows_;
ASSERT_EQ(unused_row_indices_initial.size(), unused_row_indices.size());
for (size_t i = 0; i < unused_row_indices_initial.size(); ++i) {
ASSERT_EQ(unused_row_indices_initial[i], unused_row_indices[i]);
}
check_each_row_occurs_in_one_of_arrays(row_indices, unused_row_indices,
p_fmat->Info().num_row_);
}
omp_set_num_threads(nthreads);
}
void TestApplySplit(const RegTree& tree) {
std::vector<GradientPair> row_gpairs =
{ {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
{0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
int32_t constexpr kMaxBins = 4;
// try different sparsity levels to get different numbers of missing values
for (double sparsity : {0.0, 0.1, 0.2}) {
// kNRows samples with kNCols features
auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
float sparse_th = 0.0;
GHistIndexMatrix gmat{dmat.get(), kMaxBins, sparse_th, false, common::OmpGetNumThreads(0)};
ColumnMatrix cm;
// treat everything as dense, as this is what we intend to test here
cm.Init(gmat, sparse_th, common::OmpGetNumThreads(0));
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
const size_t num_row = dmat->Info().num_row_;
// split by feature 0
const size_t bin_id_min = gmat.cut.Ptrs()[0];
const size_t bin_id_max = gmat.cut.Ptrs()[1];
// attempt to split at different bins
for (size_t split = 0; split < 4; split++) {
size_t left_cnt = 0, right_cnt = 0;
// manually compute how many samples go left or right
for (size_t rid = 0; rid < num_row; ++rid) {
for (size_t offset = gmat.row_ptr[rid]; offset < gmat.row_ptr[rid + 1]; ++offset) {
const size_t bin_id = gmat.index[offset];
if (bin_id >= bin_id_min && bin_id < bin_id_max) {
if (bin_id <= split) {
left_cnt++;
} else {
right_cnt++;
}
}
}
}
// if any were missing due to sparsity, we add them to the left or to the right
size_t missing = kNRows - left_cnt - right_cnt;
if (tree[0].DefaultLeft()) {
left_cnt += missing;
} else {
right_cnt += missing;
}
// have one node with kNRows (=8 at the moment) rows, just one task
RealImpl::partition_builder_.Init(1, 1, [&](size_t node_in_set) {
return 1;
});
const size_t task_id = RealImpl::partition_builder_.GetTaskIdx(0, 0);
RealImpl::partition_builder_.AllocateForTask(task_id);
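      // Choose the template specialization matching whether the column matrix
      // contains any missing values.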
if (cm.AnyMissing()) {
RealImpl::partition_builder_.template Partition<uint8_t, true>(0, 0, common::Range1d(0, kNRows),
split, cm, tree, this->row_set_collection_[0].begin);
} else {
RealImpl::partition_builder_.template Partition<uint8_t, false>(0, 0, common::Range1d(0, kNRows),
split, cm, tree, this->row_set_collection_[0].begin);
}
RealImpl::partition_builder_.CalculateRowOffsets();
ASSERT_EQ(RealImpl::partition_builder_.GetNLeftElems(0), left_cnt);
ASSERT_EQ(RealImpl::partition_builder_.GetNRightElems(0), right_cnt);
}
}
}
  };

  int static constexpr kNRows = 8, kNCols = 16;

@@ -262,33 +139,6 @@ class QuantileHistMock : public QuantileHistMaker {
      float_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
    }
  }
void TestInitDataSampling() {
int32_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
if (double_builder_) {
double_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
} else {
float_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
}
}
void TestApplySplit() {
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
if (double_builder_) {
double_builder_->TestApplySplit(tree);
} else {
float_builder_->TestApplySplit(tree);
}
}
};

TEST(QuantileHist, InitData) {
@@ -301,30 +151,62 @@ TEST(QuantileHist, InitData) {
  maker_float.TestInitData();
}

TEST(QuantileHist, InitDataSampling) {
  const float subsample = 0.5;
  std::vector<std::pair<std::string, std::string>> cfg
      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
       {"subsample", std::to_string(subsample)}};
  QuantileHistMock maker(cfg);
  maker.TestInitDataSampling();
  const bool single_precision_histogram = true;
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestInitDataSampling();
}

TEST(QuantileHist, ApplySplit) {
  std::vector<std::pair<std::string, std::string>> cfg
      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
       {"split_evaluator", "elastic_net"},
       {"reg_lambda", "0"}, {"reg_alpha", "0"}, {"max_delta_step", "0"},
       {"min_child_weight", "0"}};
  QuantileHistMock maker(cfg);
  maker.TestApplySplit();
  const bool single_precision_histogram = true;
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestApplySplit();
}

TEST(QuantileHist, Partitioner) {
  size_t n_samples = 1024, n_features = 1, base_rowid = 0;
  GenericParameter ctx;
  ctx.InitAllowUnknown(Args{});

  HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
  ASSERT_EQ(partitioner.base_rowid, base_rowid);
  ASSERT_EQ(partitioner.Size(), 1);
  ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);

auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
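  // A single pre-built candidate for the root node (nidx = 0, depth = 0); the
  // gain value is a placeholder for this test.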
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
auto grad = GenerateRandomGradients(n_samples);
std::vector<float> hess(grad.Size());
std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(),
[](auto gpair) { return gpair.GetHess(); });
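  // Collect per-sample Hessian values from the generated gradients.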
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({64, 0.5})) {
bst_feature_t const split_ind = 0;
common::ColumnMatrix column_indices;
column_indices.Init(page, 0.5, ctx.Threads());
{
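      // Split at the feature's minimum value: no sample can fall below it, so
      // the left child stays empty and all rows go to the right child.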
auto min_value = page.cut.MinValues()[split_ind];
RegTree tree;
HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
GetSplit(&tree, min_value, &candidates);
partitioner.UpdatePosition<false, true>(&ctx, page, column_indices, candidates, &tree);
ASSERT_EQ(partitioner.Size(), 3);
ASSERT_EQ(partitioner[1].Size(), 0);
ASSERT_EQ(partitioner[2].Size(), n_samples);
}
{
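      // Split in the middle of the feature's value range and verify that rows
      // with values <= split_value land in the left child and the rest in the
      // right child.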
HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
auto ptr = page.cut.Ptrs()[split_ind + 1];
float split_value = page.cut.Values().at(ptr / 2);
RegTree tree;
GetSplit(&tree, split_value, &candidates);
auto left_nidx = tree[RegTree::kRoot].LeftChild();
partitioner.UpdatePosition<false, true>(&ctx, page, column_indices, candidates, &tree);
auto elem = partitioner[left_nidx];
ASSERT_LT(elem.Size(), n_samples);
ASSERT_GT(elem.Size(), 1);
for (auto it = elem.begin; it != elem.end; ++it) {
auto value = page.cut.Values().at(page.index[*it]);
ASSERT_LE(value, split_value);
}
auto right_nidx = tree[RegTree::kRoot].RightChild();
elem = partitioner[right_nidx];
for (auto it = elem.begin; it != elem.end; ++it) {
auto value = page.cut.Values().at(page.index[*it]);
ASSERT_GT(value, split_value) << *it;
}
}
}
}
} // namespace tree
} // namespace xgboost

View File

@@ -245,3 +245,4 @@ class TestTreeMethod:
    @pytest.mark.skipif(**tm.no_pandas())
    def test_categorical(self, rows, cols, rounds, cats):
        self.run_categorical_basic(rows, cols, rounds, cats, "approx")
        self.run_categorical_basic(rows, cols, rounds, cats, "hist")