Support categorical data for hist. (#7695)
* Extract partitioner from hist. * Implement categorical data support by passing the gradient index directly into the partitioner. * Organize/update document. * Remove code for negative hessian.
This commit is contained in:
parent
f60d95b0ba
commit
83a66b4994
@ -244,9 +244,6 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
|
|||||||
|
|
||||||
- Use single precision to build histograms instead of double precision.
|
- Use single precision to build histograms instead of double precision.
|
||||||
|
|
||||||
Additional parameters for ``approx`` and ``gpu_hist`` tree method
|
|
||||||
=================================================================
|
|
||||||
|
|
||||||
* ``max_cat_to_onehot``
|
* ``max_cat_to_onehot``
|
||||||
|
|
||||||
.. versionadded:: 1.6
|
.. versionadded:: 1.6
|
||||||
@ -256,8 +253,8 @@ Additional parameters for ``approx`` and ``gpu_hist`` tree method
|
|||||||
- A threshold for deciding whether XGBoost should use one-hot encoding based split for
|
- A threshold for deciding whether XGBoost should use one-hot encoding based split for
|
||||||
categorical data. When the number of categories is less than the threshold then one-hot
|
categorical data. When the number of categories is less than the threshold then one-hot
|
||||||
encoding is chosen, otherwise the categories will be partitioned into children nodes.
|
encoding is chosen, otherwise the categories will be partitioned into children nodes.
|
||||||
Only relevant for regression and binary classification. Also, `approx` or `gpu_hist`
|
Only relevant for regression and binary classification. Also, ``exact`` tree method is
|
||||||
tree method is required.
|
not supported.
|
||||||
|
|
||||||
Additional parameters for Dart Booster (``booster=dart``)
|
Additional parameters for Dart Booster (``booster=dart``)
|
||||||
=========================================================
|
=========================================================
|
||||||
|
|||||||
@ -4,16 +4,16 @@ Categorical Data
|
|||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
As of XGBoost 1.6, the feature is highly experimental and has limited features
|
As of XGBoost 1.6, the feature is experimental and has limited features
|
||||||
|
|
||||||
Starting from version 1.5, XGBoost has experimental support for categorical data available
|
Starting from version 1.5, XGBoost has experimental support for categorical data available
|
||||||
for public testing. At the moment, the support is implemented as one-hot encoding based
|
for public testing. For numerical data, the split condition is defined as :math:`value <
|
||||||
categorical tree splits. For numerical data, the split condition is defined as
|
threshold`, while for categorical data the split is defined depending on whether
|
||||||
:math:`value < threshold`, while for categorical data the split is defined as :math:`value
|
partitioning or one-hot encoding is used. For partition-based splits, the splits are
|
||||||
== category` and ``category`` is a discrete value. More advanced categorical split
|
specified as :math:`value \in categories`, where ``categories`` is the set of categories
|
||||||
strategy is planned for future releases and this tutorial details how to inform XGBoost
|
in one feature. If one-hot encoding is used instead, then the split is defined as
|
||||||
about the data type. Also, the current support for training is limited to ``gpu_hist``
|
:math:`value == category`. More advanced categorical split strategy is planned for future
|
||||||
tree method.
|
releases and this tutorial details how to inform XGBoost about the data type.
|
||||||
|
|
||||||
************************************
|
************************************
|
||||||
Training with scikit-learn Interface
|
Training with scikit-learn Interface
|
||||||
@ -35,13 +35,13 @@ parameter ``enable_categorical``:
|
|||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
# Only gpu_hist is supported for categorical data as mentioned previously
|
# Supported tree methods are `gpu_hist`, `approx`, and `hist`.
|
||||||
clf = xgb.XGBClassifier(
|
clf = xgb.XGBClassifier(
|
||||||
tree_method="gpu_hist", enable_categorical=True, use_label_encoder=False
|
tree_method="gpu_hist", enable_categorical=True, use_label_encoder=False
|
||||||
)
|
)
|
||||||
# X is the dataframe we created in previous snippet
|
# X is the dataframe we created in previous snippet
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
# Must use JSON for serialization, otherwise the information is lost
|
# Must use JSON/UBJSON for serialization, otherwise the information is lost.
|
||||||
clf.save_model("categorical-model.json")
|
clf.save_model("categorical-model.json")
|
||||||
|
|
||||||
|
|
||||||
@ -60,11 +60,37 @@ can plot the model and calculate the global feature importance:
|
|||||||
|
|
||||||
|
|
||||||
The ``scikit-learn`` interface from dask is similar to single node version. The basic
|
The ``scikit-learn`` interface from dask is similar to single node version. The basic
|
||||||
idea is create dataframe with category feature type, and tell XGBoost to use ``gpu_hist``
|
idea is to create a dataframe with category feature type, and tell XGBoost to use it by setting
|
||||||
with parameter ``enable_categorical``. See :ref:`sphx_glr_python_examples_categorical.py`
|
the ``enable_categorical`` parameter. See :ref:`sphx_glr_python_examples_categorical.py`
|
||||||
for a worked example of using categorical data with ``scikit-learn`` interface. A
|
for a worked example of using categorical data with ``scikit-learn`` interface with
|
||||||
comparison between using one-hot encoded data and XGBoost's categorical data support can
|
one-hot encoding. A comparison between using one-hot encoded data and XGBoost's
|
||||||
be found :ref:`sphx_glr_python_examples_cat_in_the_dat.py`.
|
categorical data support can be found :ref:`sphx_glr_python_examples_cat_in_the_dat.py`.
|
||||||
|
|
||||||
|
|
||||||
|
********************
|
||||||
|
Optimal Partitioning
|
||||||
|
********************
|
||||||
|
|
||||||
|
.. versionadded:: 1.6
|
||||||
|
|
||||||
|
Optimal partitioning is a technique for partitioning the categorical predictors for each
|
||||||
|
node split. The proof of optimality for numerical objectives like ``RMSE`` was first
|
||||||
|
introduced by `[1] <#references>`__. The algorithm is used in decision trees for handling
|
||||||
|
regression and binary classification tasks `[2] <#references>`__, later LightGBM `[3]
|
||||||
|
<#references>`__ brought it to the context of gradient boosting trees and now is also
|
||||||
|
adopted in XGBoost as an optional feature for handling categorical splits. More
|
||||||
|
specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to
|
||||||
|
partition a set of discrete values into groups based on the distances between a measure of
|
||||||
|
these values, one only needs to look at sorted partitions instead of enumerating all
|
||||||
|
possible permutations. In the context of decision trees, the discrete values are
|
||||||
|
categories, and the measure is the output leaf value. Intuitively, we want to group the
|
||||||
|
categories that output similar leaf values. During split finding, we first sort the
|
||||||
|
gradient histogram to prepare the contiguous partitions then enumerate the splits
|
||||||
|
according to these sorted values. One of the related parameters for XGBoost is
|
||||||
|
``max_cat_to_onehot``, which controls whether one-hot encoding or partitioning should be
|
||||||
|
used for each feature, see :doc:`/parameter` for details. When objective is not
|
||||||
|
regression or binary classification, XGBoost will fall back to using one-hot encoding
|
||||||
|
instead.
|
||||||
|
|
||||||
|
|
||||||
**********************
|
**********************
|
||||||
@ -82,7 +108,7 @@ categorical data, we need to pass the similar parameter to :class:`DMatrix
|
|||||||
|
|
||||||
# X is a dataframe we created in previous snippet
|
# X is a dataframe we created in previous snippet
|
||||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||||
booster = xgb.train({"tree_method": "gpu_hist"}, Xy)
|
booster = xgb.train({"tree_method": "hist", "max_cat_to_onehot": 5}, Xy)
|
||||||
# Must use JSON for serialization, otherwise the information is lost
|
# Must use JSON for serialization, otherwise the information is lost
|
||||||
booster.save_model("categorical-model.json")
|
booster.save_model("categorical-model.json")
|
||||||
|
|
||||||
@ -109,30 +135,7 @@ types by using the ``feature_types`` parameter in :class:`DMatrix <xgboost.DMatr
|
|||||||
|
|
||||||
For numerical data, the feature type can be ``"q"`` or ``"float"``, while for categorical
|
For numerical data, the feature type can be ``"q"`` or ``"float"``, while for categorical
|
||||||
feature it's specified as ``"c"``. The Dask module in XGBoost has the same interface so
|
feature it's specified as ``"c"``. The Dask module in XGBoost has the same interface so
|
||||||
:class:`dask.Array <dask.Array>` can also be used as categorical data.
|
:class:`dask.Array <dask.Array>` can also be used for categorical data.
|
||||||
|
|
||||||
********************
|
|
||||||
Optimal Partitioning
|
|
||||||
********************
|
|
||||||
|
|
||||||
.. versionadded:: 1.6
|
|
||||||
|
|
||||||
Optimal partitioning is a technique for partitioning the categorical predictors for each
|
|
||||||
node split, the proof of optimality for numerical objectives like ``RMSE`` was first
|
|
||||||
introduced by `[1] <#references>`__. The algorithm is used in decision trees for handling
|
|
||||||
regression and binary classification tasks `[2] <#references>`__, later LightGBM `[3]
|
|
||||||
<#references>`__ brought it to the context of gradient boosting trees and now is also
|
|
||||||
adopted in XGBoost as an optional feature for handling categorical splits. More
|
|
||||||
specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to
|
|
||||||
partition a set of discrete values into groups based on the distances between a measure of
|
|
||||||
these values, one only needs to look at sorted partitions instead of enumerating all
|
|
||||||
possible permutations. In the context of decision trees, the discrete values are
|
|
||||||
categories, and the measure is the output leaf value. Intuitively, we want to group the
|
|
||||||
categories that output similar leaf values. During split finding, we first sort the
|
|
||||||
gradient histogram to prepare the contiguous partitions then enumerate the splits
|
|
||||||
according to these sorted values. One of the related parameters for XGBoost is
|
|
||||||
``max_cat_to_onehot``, which controls whether one-hot encoding or partitioning should be
|
|
||||||
used for each feature, see :doc:`/parameter` for details.
|
|
||||||
|
|
||||||
*************
|
*************
|
||||||
Miscellaneous
|
Miscellaneous
|
||||||
|
|||||||
@ -604,6 +604,16 @@ class RegTree : public Model {
|
|||||||
*/
|
*/
|
||||||
std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; }
|
std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; }
|
||||||
common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; }
|
common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; }
|
||||||
|
/*!
|
||||||
|
* \brief Get the bit storage for the categories of node nidx.
|
||||||
|
*/
|
||||||
|
common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
|
||||||
|
auto node_ptr = GetCategoriesMatrix().node_ptr;
|
||||||
|
auto categories = GetCategoriesMatrix().categories;
|
||||||
|
auto segment = node_ptr[nidx];
|
||||||
|
auto node_cats = categories.subspan(segment.beg, segment.size);
|
||||||
|
return node_cats;
|
||||||
|
}
|
||||||
auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
|
auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
|
||||||
|
|
||||||
// The fields of split_categories_segments_[i] are set such that
|
// The fields of split_categories_segments_[i] are set such that
|
||||||
|
|||||||
@ -582,10 +582,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
|
|
||||||
.. versionadded:: 1.3.0
|
.. versionadded:: 1.3.0
|
||||||
|
|
||||||
|
.. note:: This parameter is experimental
|
||||||
|
|
||||||
Experimental support of specializing for categorical features. Do not set
|
Experimental support of specializing for categorical features. Do not set
|
||||||
to True unless you are interested in development. Currently it's only
|
to True unless you are interested in development. Also, JSON/UBJSON
|
||||||
available for `gpu_hist` and `approx` tree methods. Also, JSON/UBJSON
|
serialization format is required.
|
||||||
serialization format is required. (XGBoost 1.6 for approx)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if group is not None and qid is not None:
|
if group is not None and qid is not None:
|
||||||
|
|||||||
@ -206,10 +206,11 @@ __model_doc = f'''
|
|||||||
|
|
||||||
.. versionadded:: 1.5.0
|
.. versionadded:: 1.5.0
|
||||||
|
|
||||||
Experimental support for categorical data. Do not set to true unless you are
|
.. note:: This parameter is experimental
|
||||||
interested in development. Only valid when `gpu_hist` or `approx` is used along
|
|
||||||
with dataframe as input. Also, JSON/UBJSON serialization format is
|
Experimental support for categorical data. When enabled, cudf/pandas.DataFrame
|
||||||
required. (XGBoost 1.6 for approx)
|
should be used to specify categorical data type. Also, JSON/UBJSON
|
||||||
|
serialization format is required.
|
||||||
|
|
||||||
max_cat_to_onehot : Optional[int]
|
max_cat_to_onehot : Optional[int]
|
||||||
|
|
||||||
@ -220,9 +221,8 @@ __model_doc = f'''
|
|||||||
A threshold for deciding whether XGBoost should use one-hot encoding based split
|
A threshold for deciding whether XGBoost should use one-hot encoding based split
|
||||||
for categorical data. When number of categories is lesser than the threshold
|
for categorical data. When number of categories is lesser than the threshold
|
||||||
then one-hot encoding is chosen, otherwise the categories will be partitioned
|
then one-hot encoding is chosen, otherwise the categories will be partitioned
|
||||||
into children nodes. Only relevant for regression and binary
|
into children nodes. Only relevant for regression and binary classification.
|
||||||
classification. Also, ``approx`` or ``gpu_hist`` tree method is required. See
|
See :doc:`Categorical Data </tutorials/categorical>` for details.
|
||||||
:doc:`Categorical Data </tutorials/categorical>` for details.
|
|
||||||
|
|
||||||
eval_metric : Optional[Union[str, List[str], Callable]]
|
eval_metric : Optional[Union[str, List[str], Callable]]
|
||||||
|
|
||||||
@ -846,7 +846,8 @@ class XGBModel(XGBModelBase):
|
|||||||
callbacks = self.callbacks if self.callbacks is not None else callbacks
|
callbacks = self.callbacks if self.callbacks is not None else callbacks
|
||||||
|
|
||||||
tree_method = params.get("tree_method", None)
|
tree_method = params.get("tree_method", None)
|
||||||
if self.enable_categorical and tree_method not in ("gpu_hist", "approx"):
|
cat_support = {"gpu_hist", "approx", "hist"}
|
||||||
|
if self.enable_categorical and tree_method not in cat_support:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Experimental support for categorical data is not implemented for"
|
"Experimental support for categorical data is not implemented for"
|
||||||
" current tree method yet."
|
" current tree method yet."
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2021 by Contributors
|
* Copyright 2021-2022 by Contributors
|
||||||
* \file row_set.h
|
* \file row_set.h
|
||||||
* \brief Quick Utility to compute subset of rows
|
* \brief Quick Utility to compute subset of rows
|
||||||
* \author Philip Cho, Tianqi Chen
|
* \author Philip Cho, Tianqi Chen
|
||||||
@ -8,12 +8,15 @@
|
|||||||
#define XGBOOST_COMMON_PARTITION_BUILDER_H_
|
#define XGBOOST_COMMON_PARTITION_BUILDER_H_
|
||||||
|
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
|
||||||
#include <utility>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "categorical.h"
|
||||||
|
#include "column_matrix.h"
|
||||||
#include "xgboost/tree_model.h"
|
#include "xgboost/tree_model.h"
|
||||||
#include "../common/column_matrix.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
@ -46,18 +49,24 @@ class PartitionBuilder {
|
|||||||
// on comparison of indexes values (idx_span) and split point (split_cond)
|
// on comparison of indexes values (idx_span) and split point (split_cond)
|
||||||
// Handle dense columns
|
// Handle dense columns
|
||||||
// Analog of std::stable_partition, but in no-inplace manner
|
// Analog of std::stable_partition, but in no-inplace manner
|
||||||
template <bool default_left, bool any_missing, typename ColumnType>
|
template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
|
||||||
inline std::pair<size_t, size_t> PartitionKernel(const ColumnType& column,
|
inline std::pair<size_t, size_t> PartitionKernel(const ColumnType& column,
|
||||||
common::Span<const size_t> rid_span, const int32_t split_cond,
|
common::Span<const size_t> row_indices,
|
||||||
common::Span<size_t> left_part, common::Span<size_t> right_part) {
|
common::Span<size_t> left_part,
|
||||||
|
common::Span<size_t> right_part,
|
||||||
|
size_t base_rowid, Predicate&& pred) {
|
||||||
size_t* p_left_part = left_part.data();
|
size_t* p_left_part = left_part.data();
|
||||||
size_t* p_right_part = right_part.data();
|
size_t* p_right_part = right_part.data();
|
||||||
size_t nleft_elems = 0;
|
size_t nleft_elems = 0;
|
||||||
size_t nright_elems = 0;
|
size_t nright_elems = 0;
|
||||||
auto state = column.GetInitialState(rid_span.front());
|
auto state = column.GetInitialState(row_indices.front() - base_rowid);
|
||||||
|
|
||||||
for (auto rid : rid_span) {
|
auto p_row_indices = row_indices.data();
|
||||||
const int32_t bin_id = column.GetBinIdx(rid, &state);
|
auto n_samples = row_indices.size();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < n_samples; ++i) {
|
||||||
|
auto rid = p_row_indices[i];
|
||||||
|
const int32_t bin_id = column.GetBinIdx(rid - base_rowid, &state);
|
||||||
if (any_missing && bin_id == ColumnType::kMissingId) {
|
if (any_missing && bin_id == ColumnType::kMissingId) {
|
||||||
if (default_left) {
|
if (default_left) {
|
||||||
p_left_part[nleft_elems++] = rid;
|
p_left_part[nleft_elems++] = rid;
|
||||||
@ -65,7 +74,7 @@ class PartitionBuilder {
|
|||||||
p_right_part[nright_elems++] = rid;
|
p_right_part[nright_elems++] = rid;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (bin_id <= split_cond) {
|
if (pred(rid, bin_id)) {
|
||||||
p_left_part[nleft_elems++] = rid;
|
p_left_part[nleft_elems++] = rid;
|
||||||
} else {
|
} else {
|
||||||
p_right_part[nright_elems++] = rid;
|
p_right_part[nright_elems++] = rid;
|
||||||
@ -95,41 +104,66 @@ class PartitionBuilder {
|
|||||||
return {nleft_elems, nright_elems};
|
return {nleft_elems, nright_elems};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename BinIdxType, bool any_missing>
|
template <typename BinIdxType, bool any_missing, bool any_cat>
|
||||||
void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
|
void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
|
||||||
const int32_t split_cond,
|
const int32_t split_cond, GHistIndexMatrix const& gmat,
|
||||||
const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
|
const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
|
||||||
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
|
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
|
||||||
common::Span<size_t> left = GetLeftBuffer(node_in_set,
|
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
|
||||||
range.begin(), range.end());
|
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
|
||||||
common::Span<size_t> right = GetRightBuffer(node_in_set,
|
|
||||||
range.begin(), range.end());
|
|
||||||
const bst_uint fid = tree[nid].SplitIndex();
|
const bst_uint fid = tree[nid].SplitIndex();
|
||||||
const bool default_left = tree[nid].DefaultLeft();
|
const bool default_left = tree[nid].DefaultLeft();
|
||||||
const auto column_ptr = column_matrix.GetColumn<BinIdxType, any_missing>(fid);
|
const auto column_ptr = column_matrix.GetColumn<BinIdxType, any_missing>(fid);
|
||||||
|
|
||||||
std::pair<size_t, size_t> child_nodes_sizes;
|
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
|
||||||
|
auto node_cats = tree.NodeCats(nid);
|
||||||
|
|
||||||
|
auto const& index = gmat.index;
|
||||||
|
auto const& cut_values = gmat.cut.Values();
|
||||||
|
auto const& cut_ptrs = gmat.cut.Ptrs();
|
||||||
|
|
||||||
|
auto pred = [&](auto ridx, auto bin_id) {
|
||||||
|
if (any_cat && is_cat) {
|
||||||
|
auto begin = gmat.RowIdx(ridx);
|
||||||
|
auto end = gmat.RowIdx(ridx + 1);
|
||||||
|
auto f_begin = cut_ptrs[fid];
|
||||||
|
auto f_end = cut_ptrs[fid + 1];
|
||||||
|
// bypassing the column matrix as we need the cut value instead of bin idx for categorical
|
||||||
|
// features.
|
||||||
|
auto gidx = BinarySearchBin(begin, end, index, f_begin, f_end);
|
||||||
|
bool go_left;
|
||||||
|
if (gidx == -1) {
|
||||||
|
go_left = default_left;
|
||||||
|
} else {
|
||||||
|
go_left = Decision(node_cats, cut_values[gidx], default_left);
|
||||||
|
}
|
||||||
|
return go_left;
|
||||||
|
} else {
|
||||||
|
return bin_id <= split_cond;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::pair<size_t, size_t> child_nodes_sizes;
|
||||||
if (column_ptr->GetType() == xgboost::common::kDenseColumn) {
|
if (column_ptr->GetType() == xgboost::common::kDenseColumn) {
|
||||||
const common::DenseColumn<BinIdxType, any_missing>& column =
|
const common::DenseColumn<BinIdxType, any_missing>& column =
|
||||||
static_cast<const common::DenseColumn<BinIdxType, any_missing>& >(*(column_ptr.get()));
|
static_cast<const common::DenseColumn<BinIdxType, any_missing>& >(*(column_ptr.get()));
|
||||||
if (default_left) {
|
if (default_left) {
|
||||||
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span,
|
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
|
||||||
split_cond, left, right);
|
gmat.base_rowid, pred);
|
||||||
} else {
|
} else {
|
||||||
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span,
|
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
|
||||||
split_cond, left, right);
|
gmat.base_rowid, pred);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
CHECK_EQ(any_missing, true);
|
CHECK_EQ(any_missing, true);
|
||||||
const common::SparseColumn<BinIdxType>& column
|
const common::SparseColumn<BinIdxType>& column
|
||||||
= static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
|
= static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
|
||||||
if (default_left) {
|
if (default_left) {
|
||||||
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span,
|
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
|
||||||
split_cond, left, right);
|
gmat.base_rowid, pred);
|
||||||
} else {
|
} else {
|
||||||
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span,
|
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
|
||||||
split_cond, left, right);
|
gmat.base_rowid, pred);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -275,9 +275,6 @@ class MemStackAllocator {
|
|||||||
T& operator[](size_t i) { return ptr_[i]; }
|
T& operator[](size_t i) { return ptr_[i]; }
|
||||||
T const& operator[](size_t i) const { return ptr_[i]; }
|
T const& operator[](size_t i) const { return ptr_[i]; }
|
||||||
|
|
||||||
// FIXME(jiamingy): Remove this once we merge partitioner cleanup for hist.
|
|
||||||
auto Get() { return ptr_; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
T* ptr_ = nullptr;
|
T* ptr_ = nullptr;
|
||||||
size_t required_size_;
|
size_t required_size_;
|
||||||
|
|||||||
@ -288,10 +288,10 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
|
|||||||
auto base_weight =
|
auto base_weight =
|
||||||
evaluator.CalcWeight(candidate.nid, param_, GradStats{parent_sum});
|
evaluator.CalcWeight(candidate.nid, param_, GradStats{parent_sum});
|
||||||
|
|
||||||
auto left_weight = evaluator.CalcWeight(
|
auto left_weight =
|
||||||
candidate.nid, param_, GradStats{candidate.split.left_sum});
|
evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.left_sum});
|
||||||
auto right_weight = evaluator.CalcWeight(
|
auto right_weight =
|
||||||
candidate.nid, param_, GradStats{candidate.split.right_sum});
|
evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.right_sum});
|
||||||
|
|
||||||
if (candidate.split.is_cat) {
|
if (candidate.split.is_cat) {
|
||||||
std::vector<uint32_t> split_cats;
|
std::vector<uint32_t> split_cats;
|
||||||
@ -308,11 +308,11 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
|
|||||||
split_cats = candidate.split.cat_bits;
|
split_cats = candidate.split.cat_bits;
|
||||||
common::CatBitField cat_bits{split_cats};
|
common::CatBitField cat_bits{split_cats};
|
||||||
}
|
}
|
||||||
|
|
||||||
tree.ExpandCategorical(
|
tree.ExpandCategorical(
|
||||||
candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(),
|
candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(),
|
||||||
base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
|
base_weight, left_weight * param_.learning_rate, right_weight * param_.learning_rate,
|
||||||
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
|
candidate.split.loss_chg, parent_sum.GetHess(), candidate.split.left_sum.GetHess(),
|
||||||
|
candidate.split.right_sum.GetHess());
|
||||||
} else {
|
} else {
|
||||||
tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
|
tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
|
||||||
candidate.split.DefaultLeft(), base_weight,
|
candidate.split.DefaultLeft(), base_weight,
|
||||||
|
|||||||
@ -124,11 +124,12 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
|
|||||||
nodes_for_subtraction_trick_.clear();
|
nodes_for_subtraction_trick_.clear();
|
||||||
nodes_for_explicit_hist_build_.push_back(node);
|
nodes_for_explicit_hist_build_.push_back(node);
|
||||||
|
|
||||||
|
auto const& row_set_collection = partitioner_.front().Partitions();
|
||||||
size_t page_id = 0;
|
size_t page_id = 0;
|
||||||
for (auto const& gidx :
|
for (auto const& gidx :
|
||||||
p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||||
this->histogram_builder_->BuildHist(
|
this->histogram_builder_->BuildHist(
|
||||||
page_id, gidx, p_tree, row_set_collection_,
|
page_id, gidx, p_tree, row_set_collection,
|
||||||
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h);
|
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h);
|
||||||
++page_id;
|
++page_id;
|
||||||
}
|
}
|
||||||
@ -149,7 +150,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
|
|||||||
grad_stat.Add(et.GetGrad(), et.GetHess());
|
grad_stat.Add(et.GetGrad(), et.GetHess());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const common::RowSetCollection::Elem e = row_set_collection_[nid];
|
const common::RowSetCollection::Elem e = row_set_collection[nid];
|
||||||
for (const size_t *it = e.begin; it < e.end; ++it) {
|
for (const size_t *it = e.begin; it < e.end; ++it) {
|
||||||
grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
|
grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
|
||||||
}
|
}
|
||||||
@ -204,6 +205,7 @@ void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
|
|||||||
const std::vector<CPUExpandEntry> &nodes_for_apply_split,
|
const std::vector<CPUExpandEntry> &nodes_for_apply_split,
|
||||||
std::vector<CPUExpandEntry> *nodes_to_evaluate, RegTree *p_tree) {
|
std::vector<CPUExpandEntry> *nodes_to_evaluate, RegTree *p_tree) {
|
||||||
builder_monitor_.Start("SplitSiblings");
|
builder_monitor_.Start("SplitSiblings");
|
||||||
|
auto const& row_set_collection = this->partitioner_.front().Partitions();
|
||||||
for (auto const& entry : nodes_for_apply_split) {
|
for (auto const& entry : nodes_for_apply_split) {
|
||||||
int nid = entry.nid;
|
int nid = entry.nid;
|
||||||
|
|
||||||
@ -213,7 +215,7 @@ void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
|
|||||||
const CPUExpandEntry right_node = CPUExpandEntry(cright, p_tree->GetDepth(cright), 0.0);
|
const CPUExpandEntry right_node = CPUExpandEntry(cright, p_tree->GetDepth(cright), 0.0);
|
||||||
nodes_to_evaluate->push_back(left_node);
|
nodes_to_evaluate->push_back(left_node);
|
||||||
nodes_to_evaluate->push_back(right_node);
|
nodes_to_evaluate->push_back(right_node);
|
||||||
if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
|
if (row_set_collection[cleft].Size() < row_set_collection[cright].Size()) {
|
||||||
nodes_for_explicit_hist_build_.push_back(left_node);
|
nodes_for_explicit_hist_build_.push_back(left_node);
|
||||||
nodes_for_subtraction_trick_.push_back(right_node);
|
nodes_for_subtraction_trick_.push_back(right_node);
|
||||||
} else {
|
} else {
|
||||||
@ -253,16 +255,23 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
|
|||||||
AddSplitsToTree(expand, p_tree, &num_leaves, &nodes_for_apply_split);
|
AddSplitsToTree(expand, p_tree, &num_leaves, &nodes_for_apply_split);
|
||||||
|
|
||||||
if (nodes_for_apply_split.size() != 0) {
|
if (nodes_for_apply_split.size() != 0) {
|
||||||
ApplySplit<any_missing>(nodes_for_apply_split, gmat, column_matrix, p_tree);
|
HistRowPartitioner &partitioner = this->partitioner_.front();
|
||||||
|
if (gmat.cut.HasCategorical()) {
|
||||||
|
partitioner.UpdatePosition<any_missing, true>(this->ctx_, gmat, column_matrix,
|
||||||
|
nodes_for_apply_split, p_tree);
|
||||||
|
} else {
|
||||||
|
partitioner.UpdatePosition<any_missing, false>(this->ctx_, gmat, column_matrix,
|
||||||
|
nodes_for_apply_split, p_tree);
|
||||||
|
}
|
||||||
|
|
||||||
SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree);
|
SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree);
|
||||||
|
|
||||||
if (param_.max_depth == 0 || depth < param_.max_depth) {
|
if (param_.max_depth == 0 || depth < param_.max_depth) {
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
for (auto const& gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||||
this->histogram_builder_->BuildHist(
|
this->histogram_builder_->BuildHist(i, gidx, p_tree, partitioner_.front().Partitions(),
|
||||||
i, gidx, p_tree, row_set_collection_,
|
nodes_for_explicit_hist_build_,
|
||||||
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_,
|
nodes_for_subtraction_trick_, gpair_h);
|
||||||
gpair_h);
|
|
||||||
++i;
|
++i;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -293,7 +302,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
|
|||||||
template <typename GradientSumT>
|
template <typename GradientSumT>
|
||||||
void QuantileHistMaker::Builder<GradientSumT>::Update(
|
void QuantileHistMaker::Builder<GradientSumT>::Update(
|
||||||
const GHistIndexMatrix &gmat,
|
const GHistIndexMatrix &gmat,
|
||||||
const ColumnMatrix &column_matrix,
|
const common::ColumnMatrix &column_matrix,
|
||||||
HostDeviceVector<GradientPair> *gpair,
|
HostDeviceVector<GradientPair> *gpair,
|
||||||
DMatrix *p_fmat, RegTree *p_tree) {
|
DMatrix *p_fmat, RegTree *p_tree) {
|
||||||
builder_monitor_.Start("Update");
|
builder_monitor_.Start("Update");
|
||||||
@ -333,14 +342,14 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
|||||||
|
|
||||||
CHECK_GT(out_preds.Size(), 0U);
|
CHECK_GT(out_preds.Size(), 0U);
|
||||||
|
|
||||||
size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin();
|
CHECK_EQ(partitioner_.size(), 1);
|
||||||
|
auto const &row_set_collection = this->partitioner_.front().Partitions();
|
||||||
common::BlockedSpace2d space(n_nodes, [&](size_t node) {
|
size_t n_nodes = row_set_collection.end() - row_set_collection.begin();
|
||||||
return row_set_collection_[node].Size();
|
common::BlockedSpace2d space(
|
||||||
}, 1024);
|
n_nodes, [&](size_t node) { return partitioner_.front()[node].Size(); }, 1024);
|
||||||
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
|
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
|
||||||
common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node, common::Range1d r) {
|
common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node, common::Range1d r) {
|
||||||
const RowSetCollection::Elem rowset = row_set_collection_[node];
|
const common::RowSetCollection::Elem rowset = row_set_collection[node];
|
||||||
if (rowset.begin != nullptr && rowset.end != nullptr) {
|
if (rowset.begin != nullptr && rowset.end != nullptr) {
|
||||||
int nid = rowset.node_id;
|
int nid = rowset.node_id;
|
||||||
bst_float leaf_value;
|
bst_float leaf_value;
|
||||||
@ -354,7 +363,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
|||||||
}
|
}
|
||||||
leaf_value = (*p_last_tree_)[nid].LeafValue();
|
leaf_value = (*p_last_tree_)[nid].LeafValue();
|
||||||
|
|
||||||
for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
|
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
|
||||||
out_preds(*it) += leaf_value;
|
out_preds(*it) += leaf_value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -364,10 +373,9 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename GradientSumT>
|
template <typename GradientSumT>
|
||||||
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix& fmat,
|
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix& fmat,
|
||||||
std::vector<GradientPair>* gpair,
|
std::vector<GradientPair>* gpair) {
|
||||||
std::vector<size_t>* row_indices) {
|
|
||||||
const auto& info = fmat.Info();
|
const auto& info = fmat.Info();
|
||||||
auto& rnd = common::GlobalRandom();
|
auto& rnd = common::GlobalRandom();
|
||||||
std::vector<GradientPair>& gpair_ref = *gpair;
|
std::vector<GradientPair>& gpair_ref = *gpair;
|
||||||
@ -410,101 +418,31 @@ template <typename GradientSumT>
|
|||||||
void QuantileHistMaker::Builder<GradientSumT>::InitData(
|
void QuantileHistMaker::Builder<GradientSumT>::InitData(
|
||||||
const GHistIndexMatrix &gmat, const DMatrix &fmat, const RegTree &tree,
|
const GHistIndexMatrix &gmat, const DMatrix &fmat, const RegTree &tree,
|
||||||
std::vector<GradientPair> *gpair) {
|
std::vector<GradientPair> *gpair) {
|
||||||
CHECK((param_.max_depth > 0 || param_.max_leaves > 0))
|
|
||||||
<< "max_depth or max_leaves cannot be both 0 (unlimited); "
|
|
||||||
<< "at least one should be a positive quantity.";
|
|
||||||
if (param_.grow_policy == TrainParam::kDepthWise) {
|
|
||||||
CHECK(param_.max_depth > 0) << "max_depth cannot be 0 (unlimited) "
|
|
||||||
<< "when grow_policy is depthwise.";
|
|
||||||
}
|
|
||||||
builder_monitor_.Start("InitData");
|
builder_monitor_.Start("InitData");
|
||||||
const auto& info = fmat.Info();
|
const auto& info = fmat.Info();
|
||||||
|
|
||||||
{
|
{
|
||||||
// initialize the row set
|
|
||||||
row_set_collection_.Clear();
|
|
||||||
// initialize histogram collection
|
// initialize histogram collection
|
||||||
uint32_t nbins = gmat.cut.Ptrs().back();
|
uint32_t nbins = gmat.cut.Ptrs().back();
|
||||||
// initialize histogram builder
|
// initialize histogram builder
|
||||||
dmlc::OMPException exc;
|
dmlc::OMPException exc;
|
||||||
exc.Rethrow();
|
|
||||||
this->histogram_builder_->Reset(nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin},
|
this->histogram_builder_->Reset(nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin},
|
||||||
this->ctx_->Threads(), 1, rabit::IsDistributed());
|
this->ctx_->Threads(), 1, rabit::IsDistributed());
|
||||||
|
|
||||||
std::vector<size_t>& row_indices = *row_set_collection_.Data();
|
|
||||||
row_indices.resize(info.num_row_);
|
|
||||||
size_t* p_row_indices = row_indices.data();
|
|
||||||
// mark subsample and build list of member rows
|
|
||||||
|
|
||||||
if (param_.subsample < 1.0f) {
|
if (param_.subsample < 1.0f) {
|
||||||
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
|
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
|
||||||
<< "Only uniform sampling is supported, "
|
<< "Only uniform sampling is supported, "
|
||||||
<< "gradient-based sampling is only support by GPU Hist.";
|
<< "gradient-based sampling is only support by GPU Hist.";
|
||||||
builder_monitor_.Start("InitSampling");
|
builder_monitor_.Start("InitSampling");
|
||||||
InitSampling(fmat, gpair, &row_indices);
|
InitSampling(fmat, gpair);
|
||||||
builder_monitor_.Stop("InitSampling");
|
builder_monitor_.Stop("InitSampling");
|
||||||
CHECK_EQ(row_indices.size(), info.num_row_);
|
|
||||||
// We should check that the partitioning was done correctly
|
// We should check that the partitioning was done correctly
|
||||||
// and each row of the dataset fell into exactly one of the categories
|
// and each row of the dataset fell into exactly one of the categories
|
||||||
}
|
}
|
||||||
auto n_threads = this->ctx_->Threads();
|
|
||||||
common::MemStackAllocator<bool, 128> buff(n_threads);
|
|
||||||
bool* p_buff = buff.Get();
|
|
||||||
std::fill(p_buff, p_buff + this->ctx_->Threads(), false);
|
|
||||||
|
|
||||||
const size_t block_size = info.num_row_ / n_threads + !!(info.num_row_ % n_threads);
|
|
||||||
|
|
||||||
#pragma omp parallel num_threads(n_threads)
|
|
||||||
{
|
|
||||||
exc.Run([&]() {
|
|
||||||
const size_t tid = omp_get_thread_num();
|
|
||||||
const size_t ibegin = tid * block_size;
|
|
||||||
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
|
|
||||||
static_cast<size_t>(info.num_row_));
|
|
||||||
|
|
||||||
for (size_t i = ibegin; i < iend; ++i) {
|
|
||||||
if ((*gpair)[i].GetHess() < 0.0f) {
|
|
||||||
p_buff[tid] = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
exc.Rethrow();
|
|
||||||
|
|
||||||
bool has_neg_hess = false;
|
|
||||||
for (int32_t tid = 0; tid < n_threads; ++tid) {
|
|
||||||
if (p_buff[tid]) {
|
|
||||||
has_neg_hess = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (has_neg_hess) {
|
|
||||||
size_t j = 0;
|
|
||||||
for (size_t i = 0; i < info.num_row_; ++i) {
|
|
||||||
if ((*gpair)[i].GetHess() >= 0.0f) {
|
|
||||||
p_row_indices[j++] = i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
row_indices.resize(j);
|
|
||||||
} else {
|
|
||||||
#pragma omp parallel num_threads(n_threads)
|
|
||||||
{
|
|
||||||
exc.Run([&]() {
|
|
||||||
const size_t tid = omp_get_thread_num();
|
|
||||||
const size_t ibegin = tid * block_size;
|
|
||||||
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size),
|
|
||||||
static_cast<size_t>(info.num_row_));
|
|
||||||
for (size_t i = ibegin; i < iend; ++i) {
|
|
||||||
p_row_indices[i] = i;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
exc.Rethrow();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
row_set_collection_.Init();
|
partitioner_.clear();
|
||||||
|
partitioner_.emplace_back(info.num_row_, 0, this->ctx_->Threads());
|
||||||
|
|
||||||
{
|
{
|
||||||
/* determine layout of data */
|
/* determine layout of data */
|
||||||
@ -558,12 +496,9 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
|
|||||||
builder_monitor_.Stop("InitData");
|
builder_monitor_.Stop("InitData");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename GradientSumT>
|
void HistRowPartitioner::FindSplitConditions(const std::vector<CPUExpandEntry> &nodes,
|
||||||
void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
|
const RegTree &tree, const GHistIndexMatrix &gmat,
|
||||||
const std::vector<CPUExpandEntry>& nodes,
|
std::vector<int32_t> *split_conditions) {
|
||||||
const RegTree& tree,
|
|
||||||
const GHistIndexMatrix& gmat,
|
|
||||||
std::vector<int32_t>* split_conditions) {
|
|
||||||
const size_t n_nodes = nodes.size();
|
const size_t n_nodes = nodes.size();
|
||||||
split_conditions->resize(n_nodes);
|
split_conditions->resize(n_nodes);
|
||||||
|
|
||||||
@ -576,8 +511,7 @@ void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
|
|||||||
int32_t split_cond = -1;
|
int32_t split_cond = -1;
|
||||||
// convert floating-point split_pt into corresponding bin_id
|
// convert floating-point split_pt into corresponding bin_id
|
||||||
// split_cond = -1 indicates that split_pt is less than all known cut points
|
// split_cond = -1 indicates that split_pt is less than all known cut points
|
||||||
CHECK_LT(upper_bound,
|
CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
|
||||||
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
|
|
||||||
for (uint32_t bound = lower_bound; bound < upper_bound; ++bound) {
|
for (uint32_t bound = lower_bound; bound < upper_bound; ++bound) {
|
||||||
if (split_pt == gmat.cut.Values()[bound]) {
|
if (split_pt == gmat.cut.Values()[bound]) {
|
||||||
split_cond = static_cast<int32_t>(bound);
|
split_cond = static_cast<int32_t>(bound);
|
||||||
@ -586,88 +520,20 @@ void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
|
|||||||
(*split_conditions)[i] = split_cond;
|
(*split_conditions)[i] = split_cond;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template <typename GradientSumT>
|
|
||||||
void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToRowSet(
|
void HistRowPartitioner::AddSplitsToRowSet(const std::vector<CPUExpandEntry> &nodes,
|
||||||
const std::vector<CPUExpandEntry>& nodes,
|
RegTree const *p_tree) {
|
||||||
RegTree* p_tree) {
|
|
||||||
const size_t n_nodes = nodes.size();
|
const size_t n_nodes = nodes.size();
|
||||||
for (unsigned int i = 0; i < n_nodes; ++i) {
|
for (unsigned int i = 0; i < n_nodes; ++i) {
|
||||||
const int32_t nid = nodes[i].nid;
|
const int32_t nid = nodes[i].nid;
|
||||||
const size_t n_left = partition_builder_.GetNLeftElems(i);
|
const size_t n_left = partition_builder_.GetNLeftElems(i);
|
||||||
const size_t n_right = partition_builder_.GetNRightElems(i);
|
const size_t n_right = partition_builder_.GetNRightElems(i);
|
||||||
CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
|
CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
|
||||||
row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(),
|
row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(),
|
||||||
(*p_tree)[nid].RightChild(), n_left, n_right);
|
n_left, n_right);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename GradientSumT>
|
|
||||||
template <bool any_missing>
|
|
||||||
void QuantileHistMaker::Builder<GradientSumT>::ApplySplit(const std::vector<CPUExpandEntry> nodes,
|
|
||||||
const GHistIndexMatrix& gmat,
|
|
||||||
const ColumnMatrix& column_matrix,
|
|
||||||
RegTree* p_tree) {
|
|
||||||
builder_monitor_.Start("ApplySplit");
|
|
||||||
// 1. Find split condition for each split
|
|
||||||
const size_t n_nodes = nodes.size();
|
|
||||||
std::vector<int32_t> split_conditions;
|
|
||||||
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
|
|
||||||
// 2.1 Create a blocked space of size SUM(samples in each node)
|
|
||||||
common::BlockedSpace2d space(n_nodes, [&](size_t node_in_set) {
|
|
||||||
int32_t nid = nodes[node_in_set].nid;
|
|
||||||
return row_set_collection_[nid].Size();
|
|
||||||
}, kPartitionBlockSize);
|
|
||||||
// 2.2 Initialize the partition builder
|
|
||||||
// allocate buffers for storage intermediate results by each thread
|
|
||||||
partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
|
|
||||||
const int32_t nid = nodes[node_in_set].nid;
|
|
||||||
const size_t size = row_set_collection_[nid].Size();
|
|
||||||
const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
|
|
||||||
return n_tasks;
|
|
||||||
});
|
|
||||||
// 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
|
|
||||||
// Store results in intermediate buffers from partition_builder_
|
|
||||||
common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node_in_set, common::Range1d r) {
|
|
||||||
size_t begin = r.begin();
|
|
||||||
const int32_t nid = nodes[node_in_set].nid;
|
|
||||||
const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
|
|
||||||
partition_builder_.AllocateForTask(task_id);
|
|
||||||
switch (column_matrix.GetTypeSize()) {
|
|
||||||
case common::kUint8BinsTypeSize:
|
|
||||||
partition_builder_.template Partition<uint8_t, any_missing>(node_in_set, nid, r,
|
|
||||||
split_conditions[node_in_set], column_matrix,
|
|
||||||
*p_tree, row_set_collection_[nid].begin);
|
|
||||||
break;
|
|
||||||
case common::kUint16BinsTypeSize:
|
|
||||||
partition_builder_.template Partition<uint16_t, any_missing>(node_in_set, nid, r,
|
|
||||||
split_conditions[node_in_set], column_matrix,
|
|
||||||
*p_tree, row_set_collection_[nid].begin);
|
|
||||||
break;
|
|
||||||
case common::kUint32BinsTypeSize:
|
|
||||||
partition_builder_.template Partition<uint32_t, any_missing>(node_in_set, nid, r,
|
|
||||||
split_conditions[node_in_set], column_matrix,
|
|
||||||
*p_tree, row_set_collection_[nid].begin);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
CHECK(false); // no default behavior
|
|
||||||
}
|
|
||||||
});
|
|
||||||
// 3. Compute offsets to copy blocks of row-indexes
|
|
||||||
// from partition_builder_ to row_set_collection_
|
|
||||||
partition_builder_.CalculateRowOffsets();
|
|
||||||
|
|
||||||
// 4. Copy elements from partition_builder_ to row_set_collection_ back
|
|
||||||
// with updated row-indexes for each tree-node
|
|
||||||
common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node_in_set, common::Range1d r) {
|
|
||||||
const int32_t nid = nodes[node_in_set].nid;
|
|
||||||
partition_builder_.MergeToArray(node_in_set, r.begin(),
|
|
||||||
const_cast<size_t*>(row_set_collection_[nid].begin));
|
|
||||||
});
|
|
||||||
// 5. Add info about splits into row_set_collection_
|
|
||||||
AddSplitsToRowSet(nodes, p_tree);
|
|
||||||
builder_monitor_.Stop("ApplySplit");
|
|
||||||
}
|
|
||||||
|
|
||||||
template struct QuantileHistMaker::Builder<float>;
|
template struct QuantileHistMaker::Builder<float>;
|
||||||
template struct QuantileHistMaker::Builder<double>;
|
template struct QuantileHistMaker::Builder<double>;
|
||||||
|
|
||||||
|
|||||||
@ -11,9 +11,9 @@
|
|||||||
#include <rabit/rabit.h>
|
#include <rabit/rabit.h>
|
||||||
#include <xgboost/tree_updater.h>
|
#include <xgboost/tree_updater.h>
|
||||||
|
|
||||||
#include <iomanip>
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <queue>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -38,8 +38,6 @@
|
|||||||
#include "../common/column_matrix.h"
|
#include "../common/column_matrix.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
|
|
||||||
struct RandomReplace {
|
struct RandomReplace {
|
||||||
public:
|
public:
|
||||||
// similar value as for minstd_rand
|
// similar value as for minstd_rand
|
||||||
@ -82,15 +80,127 @@ struct RandomReplace {
|
|||||||
};
|
};
|
||||||
|
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
class HistRowPartitioner {
|
||||||
|
// heuristically chosen block size of parallel partitioning
|
||||||
|
static constexpr size_t kPartitionBlockSize = 2048;
|
||||||
|
// worker class that partitions a block of rows
|
||||||
|
common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
|
||||||
|
// storage for row index
|
||||||
|
common::RowSetCollection row_set_collection_;
|
||||||
|
|
||||||
using xgboost::GHistIndexMatrix;
|
/**
|
||||||
using xgboost::common::GHistIndexRow;
|
* \brief Turn split values into discrete bin indices.
|
||||||
using xgboost::common::HistCollection;
|
*/
|
||||||
using xgboost::common::RowSetCollection;
|
static void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
|
||||||
using xgboost::common::GHistRow;
|
const GHistIndexMatrix& gmat,
|
||||||
using xgboost::common::GHistBuilder;
|
std::vector<int32_t>* split_conditions);
|
||||||
using xgboost::common::ColumnMatrix;
|
/**
|
||||||
using xgboost::common::Column;
|
* \brief Update the row set for new splits specified by nodes.
|
||||||
|
*/
|
||||||
|
void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree const* p_tree);
|
||||||
|
|
||||||
|
public:
|
||||||
|
bst_row_t base_rowid = 0;
|
||||||
|
|
||||||
|
public:
|
||||||
|
HistRowPartitioner(size_t n_samples, size_t base_rowid, int32_t n_threads) {
|
||||||
|
row_set_collection_.Clear();
|
||||||
|
const size_t block_size = n_samples / n_threads + !!(n_samples % n_threads);
|
||||||
|
dmlc::OMPException exc;
|
||||||
|
std::vector<size_t>& row_indices = *row_set_collection_.Data();
|
||||||
|
row_indices.resize(n_samples);
|
||||||
|
size_t* p_row_indices = row_indices.data();
|
||||||
|
// parallel initialization of row indices. (std::iota)
|
||||||
|
#pragma omp parallel num_threads(n_threads)
|
||||||
|
{
|
||||||
|
exc.Run([&]() {
|
||||||
|
const size_t tid = omp_get_thread_num();
|
||||||
|
const size_t ibegin = tid * block_size;
|
||||||
|
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size), n_samples);
|
||||||
|
for (size_t i = ibegin; i < iend; ++i) {
|
||||||
|
p_row_indices[i] = i + base_rowid;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
row_set_collection_.Init();
|
||||||
|
this->base_rowid = base_rowid;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool any_missing, bool any_cat>
|
||||||
|
void UpdatePosition(GenericParameter const* ctx, GHistIndexMatrix const& gmat,
|
||||||
|
common::ColumnMatrix const& column_matrix,
|
||||||
|
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
|
||||||
|
// 1. Find split condition for each split
|
||||||
|
const size_t n_nodes = nodes.size();
|
||||||
|
std::vector<int32_t> split_conditions;
|
||||||
|
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
|
||||||
|
// 2.1 Create a blocked space of size SUM(samples in each node)
|
||||||
|
common::BlockedSpace2d space(
|
||||||
|
n_nodes,
|
||||||
|
[&](size_t node_in_set) {
|
||||||
|
int32_t nid = nodes[node_in_set].nid;
|
||||||
|
return row_set_collection_[nid].Size();
|
||||||
|
},
|
||||||
|
kPartitionBlockSize);
|
||||||
|
// 2.2 Initialize the partition builder
|
||||||
|
// allocate buffers for storing intermediate results by each thread
|
||||||
|
partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
|
||||||
|
const int32_t nid = nodes[node_in_set].nid;
|
||||||
|
const size_t size = row_set_collection_[nid].Size();
|
||||||
|
const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
|
||||||
|
return n_tasks;
|
||||||
|
});
|
||||||
|
CHECK_EQ(base_rowid, gmat.base_rowid);
|
||||||
|
// 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
|
||||||
|
// Store results in intermediate buffers from partition_builder_
|
||||||
|
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
|
||||||
|
size_t begin = r.begin();
|
||||||
|
const int32_t nid = nodes[node_in_set].nid;
|
||||||
|
const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
|
||||||
|
partition_builder_.AllocateForTask(task_id);
|
||||||
|
switch (column_matrix.GetTypeSize()) {
|
||||||
|
case common::kUint8BinsTypeSize:
|
||||||
|
partition_builder_.template Partition<uint8_t, any_missing, any_cat>(
|
||||||
|
node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
|
||||||
|
row_set_collection_[nid].begin);
|
||||||
|
break;
|
||||||
|
case common::kUint16BinsTypeSize:
|
||||||
|
partition_builder_.template Partition<uint16_t, any_missing, any_cat>(
|
||||||
|
node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
|
||||||
|
row_set_collection_[nid].begin);
|
||||||
|
break;
|
||||||
|
case common::kUint32BinsTypeSize:
|
||||||
|
partition_builder_.template Partition<uint32_t, any_missing, any_cat>(
|
||||||
|
node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
|
||||||
|
row_set_collection_[nid].begin);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
// no default behavior
|
||||||
|
CHECK(false) << column_matrix.GetTypeSize();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// 3. Compute offsets to copy blocks of row-indexes
|
||||||
|
// from partition_builder_ to row_set_collection_
|
||||||
|
partition_builder_.CalculateRowOffsets();
|
||||||
|
|
||||||
|
// 4. Copy elements from partition_builder_ to row_set_collection_ back
|
||||||
|
// with updated row-indexes for each tree-node
|
||||||
|
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
|
||||||
|
const int32_t nid = nodes[node_in_set].nid;
|
||||||
|
partition_builder_.MergeToArray(node_in_set, r.begin(),
|
||||||
|
const_cast<size_t*>(row_set_collection_[nid].begin));
|
||||||
|
});
|
||||||
|
// 5. Add info about splits into row_set_collection_
|
||||||
|
AddSplitsToRowSet(nodes, p_tree);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const& Partitions() const { return row_set_collection_; }
|
||||||
|
size_t Size() const {
|
||||||
|
return std::distance(row_set_collection_.begin(), row_set_collection_.end());
|
||||||
|
}
|
||||||
|
auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
|
||||||
|
auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
|
||||||
|
};
|
||||||
|
|
||||||
inline BatchParam HistBatch(TrainParam const& param) {
|
inline BatchParam HistBatch(TrainParam const& param) {
|
||||||
return {param.max_bin, param.sparse_threshold};
|
return {param.max_bin, param.sparse_threshold};
|
||||||
@ -185,21 +295,7 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
|
|
||||||
size_t GetNumberOfTrees();
|
size_t GetNumberOfTrees();
|
||||||
|
|
||||||
void InitSampling(const DMatrix& fmat,
|
void InitSampling(const DMatrix& fmat, std::vector<GradientPair>* gpair);
|
||||||
std::vector<GradientPair>* gpair,
|
|
||||||
std::vector<size_t>* row_indices);
|
|
||||||
|
|
||||||
template <bool any_missing>
|
|
||||||
void ApplySplit(std::vector<CPUExpandEntry> nodes,
|
|
||||||
const GHistIndexMatrix& gmat,
|
|
||||||
const ColumnMatrix& column_matrix,
|
|
||||||
RegTree* p_tree);
|
|
||||||
|
|
||||||
void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree* p_tree);
|
|
||||||
|
|
||||||
|
|
||||||
void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
|
|
||||||
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions);
|
|
||||||
|
|
||||||
template <bool any_missing>
|
template <bool any_missing>
|
||||||
void InitRoot(DMatrix* p_fmat,
|
void InitRoot(DMatrix* p_fmat,
|
||||||
@ -221,7 +317,7 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
|
|
||||||
template <bool any_missing>
|
template <bool any_missing>
|
||||||
void ExpandTree(const GHistIndexMatrix& gmat,
|
void ExpandTree(const GHistIndexMatrix& gmat,
|
||||||
const ColumnMatrix& column_matrix,
|
const common::ColumnMatrix& column_matrix,
|
||||||
DMatrix* p_fmat,
|
DMatrix* p_fmat,
|
||||||
RegTree* p_tree,
|
RegTree* p_tree,
|
||||||
const std::vector<GradientPair>& gpair_h);
|
const std::vector<GradientPair>& gpair_h);
|
||||||
@ -232,9 +328,6 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
std::shared_ptr<common::ColumnSampler> column_sampler_{
|
std::shared_ptr<common::ColumnSampler> column_sampler_{
|
||||||
std::make_shared<common::ColumnSampler>()};
|
std::make_shared<common::ColumnSampler>()};
|
||||||
|
|
||||||
std::vector<size_t> unused_rows_;
|
|
||||||
// the internal row sets
|
|
||||||
RowSetCollection row_set_collection_;
|
|
||||||
std::vector<GradientPair> gpair_local_;
|
std::vector<GradientPair> gpair_local_;
|
||||||
|
|
||||||
/*! \brief feature with least # of bins. to be used for dense specialization
|
/*! \brief feature with least # of bins. to be used for dense specialization
|
||||||
@ -243,12 +336,12 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
|
|
||||||
std::unique_ptr<TreeUpdater> pruner_;
|
std::unique_ptr<TreeUpdater> pruner_;
|
||||||
std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_;
|
std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_;
|
||||||
|
// Right now there's only 1 partitioner in this vector, when external memory is fully
|
||||||
static constexpr size_t kPartitionBlockSize = 2048;
|
// supported we will have number of partitioners equal to number of pages.
|
||||||
common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
|
std::vector<HistRowPartitioner> partitioner_;
|
||||||
|
|
||||||
// back pointers to tree and data matrix
|
// back pointers to tree and data matrix
|
||||||
const RegTree* p_last_tree_;
|
const RegTree* p_last_tree_{nullptr};
|
||||||
DMatrix const* const p_last_fmat_;
|
DMatrix const* const p_last_fmat_;
|
||||||
DMatrix* p_last_fmat_mutable_;
|
DMatrix* p_last_fmat_mutable_;
|
||||||
|
|
||||||
|
|||||||
@ -40,7 +40,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
|
|||||||
std::iota(row_indices.begin(), row_indices.end(), 0);
|
std::iota(row_indices.begin(), row_indices.end(), 0);
|
||||||
row_set_collection.Init();
|
row_set_collection.Init();
|
||||||
|
|
||||||
auto hist_builder = GHistBuilder<GradientSumT>(gmat.cut.Ptrs().back());
|
auto hist_builder = common::GHistBuilder<GradientSumT>(gmat.cut.Ptrs().back());
|
||||||
hist.Init(gmat.cut.Ptrs().back());
|
hist.Init(gmat.cut.Ptrs().back());
|
||||||
hist.AddHistRow(0);
|
hist.AddHistRow(0);
|
||||||
hist.AllocateAllData();
|
hist.AllocateAllData();
|
||||||
@ -94,7 +94,7 @@ TEST(HistEvaluator, Apply) {
|
|||||||
RegTree tree;
|
RegTree tree;
|
||||||
int static constexpr kNRows = 8, kNCols = 16;
|
int static constexpr kNRows = 8, kNCols = 16;
|
||||||
TrainParam param;
|
TrainParam param;
|
||||||
param.UpdateAllowUnknown(Args{{}});
|
param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}});
|
||||||
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
|
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
|
||||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||||
auto evaluator_ = HistEvaluator<float, CPUExpandEntry>{param, dmat->Info(), 4, sampler,
|
auto evaluator_ = HistEvaluator<float, CPUExpandEntry>{param, dmat->Info(), 4, sampler,
|
||||||
@ -102,12 +102,22 @@ TEST(HistEvaluator, Apply) {
|
|||||||
|
|
||||||
CPUExpandEntry entry{0, 0, 10.0f};
|
CPUExpandEntry entry{0, 0, 10.0f};
|
||||||
entry.split.left_sum = GradStats{0.4, 0.6f};
|
entry.split.left_sum = GradStats{0.4, 0.6f};
|
||||||
entry.split.right_sum = GradStats{0.5, 0.7f};
|
entry.split.right_sum = GradStats{0.5, 0.5f};
|
||||||
|
|
||||||
evaluator_.ApplyTreeSplit(entry, &tree);
|
evaluator_.ApplyTreeSplit(entry, &tree);
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
||||||
ASSERT_EQ(tree.Stat(tree[0].LeftChild()).sum_hess, 0.6f);
|
ASSERT_EQ(tree.Stat(tree[0].LeftChild()).sum_hess, 0.6f);
|
||||||
ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.7f);
|
ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.5f);
|
||||||
|
|
||||||
|
{
|
||||||
|
RegTree tree;
|
||||||
|
entry.split.is_cat = true;
|
||||||
|
entry.split.split_value = 1.0;
|
||||||
|
evaluator_.ApplyTreeSplit(entry, &tree);
|
||||||
|
auto l = entry.split.left_sum;
|
||||||
|
ASSERT_NEAR(tree[1].LeafValue(), -l.sum_grad / l.sum_hess * param.learning_rate, kRtEps);
|
||||||
|
ASSERT_NEAR(tree[2].LeafValue(), -param.learning_rate, kRtEps);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestPartitionBasedSplit, CPUHist) {
|
TEST_F(TestPartitionBasedSplit, CPUHist) {
|
||||||
|
|||||||
@ -1,26 +1,14 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2021 XGBoost contributors
|
* Copyright 2021-2022, XGBoost contributors.
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include "../../../src/tree/updater_approx.h"
|
#include "../../../src/tree/updater_approx.h"
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
#include "test_partitioner.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
namespace {
|
|
||||||
void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) {
|
|
||||||
tree->ExpandNode(
|
|
||||||
/*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
|
|
||||||
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
|
||||||
/*left_sum=*/0.0f,
|
|
||||||
/*right_sum=*/0.0f);
|
|
||||||
candidates->front().split.split_value = split_value;
|
|
||||||
candidates->front().split.sindex = 0;
|
|
||||||
candidates->front().split.sindex |= (1U << 31);
|
|
||||||
}
|
|
||||||
} // anonymous namespace
|
|
||||||
|
|
||||||
TEST(Approx, Partitioner) {
|
TEST(Approx, Partitioner) {
|
||||||
size_t n_samples = 1024, n_features = 1, base_rowid = 0;
|
size_t n_samples = 1024, n_features = 1, base_rowid = 0;
|
||||||
ApproxRowPartitioner partitioner{n_samples, base_rowid};
|
ApproxRowPartitioner partitioner{n_samples, base_rowid};
|
||||||
|
|||||||
21
tests/cpp/tree/test_partitioner.h
Normal file
21
tests/cpp/tree/test_partitioner.h
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2021-2022, XGBoost contributors.
|
||||||
|
*/
|
||||||
|
#include <xgboost/tree_model.h>
|
||||||
|
#include <vector>
|
||||||
|
#include "../../../src/tree/hist/expand_entry.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace tree {
|
||||||
|
inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) {
|
||||||
|
tree->ExpandNode(
|
||||||
|
/*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
|
||||||
|
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||||
|
/*left_sum=*/0.0f,
|
||||||
|
/*right_sum=*/0.0f);
|
||||||
|
candidates->front().split.split_value = split_value;
|
||||||
|
candidates->front().split.sindex = 0;
|
||||||
|
candidates->front().split.sindex |= (1U << 31);
|
||||||
|
}
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace xgboost
|
||||||
@ -1,18 +1,19 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2018-2022 by XGBoost Contributors
|
* Copyright 2018-2022 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
#include <xgboost/tree_updater.h>
|
#include <xgboost/tree_updater.h>
|
||||||
#include <gtest/gtest.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "../helpers.h"
|
|
||||||
#include "../../../src/tree/param.h"
|
#include "../../../src/tree/param.h"
|
||||||
#include "../../../src/tree/updater_quantile_hist.h"
|
|
||||||
#include "../../../src/tree/split_evaluator.h"
|
#include "../../../src/tree/split_evaluator.h"
|
||||||
|
#include "../../../src/tree/updater_quantile_hist.h"
|
||||||
|
#include "../helpers.h"
|
||||||
|
#include "test_partitioner.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -94,130 +95,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestInitDataSampling(const GHistIndexMatrix& gmat,
|
|
||||||
std::vector<GradientPair>* gpair,
|
|
||||||
DMatrix* p_fmat,
|
|
||||||
const RegTree& tree) {
|
|
||||||
// check SimpleSkip
|
|
||||||
size_t initial_seed = 777;
|
|
||||||
std::linear_congruential_engine<std::uint_fast64_t, 16807, 0,
|
|
||||||
static_cast<uint64_t>(1) << 63 > eng_first(initial_seed);
|
|
||||||
for (size_t i = 0; i < 100; ++i) {
|
|
||||||
eng_first();
|
|
||||||
}
|
|
||||||
uint64_t initial_seed_th = RandomReplace::SimpleSkip(100, initial_seed, 16807, RandomReplace::kMod);
|
|
||||||
std::linear_congruential_engine<std::uint_fast64_t, RandomReplace::kBase, 0,
|
|
||||||
RandomReplace::kMod > eng_second(initial_seed_th);
|
|
||||||
ASSERT_EQ(eng_first(), eng_second());
|
|
||||||
|
|
||||||
const size_t nthreads = omp_get_num_threads();
|
|
||||||
// save state of global rng engine
|
|
||||||
auto initial_rnd = common::GlobalRandom();
|
|
||||||
std::vector<size_t> unused_rows_cpy = this->unused_rows_;
|
|
||||||
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
|
|
||||||
std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data());
|
|
||||||
std::vector<size_t> unused_row_indices_initial = this->unused_rows_;
|
|
||||||
ASSERT_EQ(row_indices_initial.size(), p_fmat->Info().num_row_);
|
|
||||||
auto check_each_row_occurs_in_one_of_arrays = [](const std::vector<size_t>& first,
|
|
||||||
const std::vector<size_t>& second,
|
|
||||||
size_t nrows) {
|
|
||||||
ASSERT_EQ(first.size(), nrows);
|
|
||||||
ASSERT_EQ(second.size(), 0);
|
|
||||||
};
|
|
||||||
check_each_row_occurs_in_one_of_arrays(row_indices_initial, unused_row_indices_initial,
|
|
||||||
p_fmat->Info().num_row_);
|
|
||||||
|
|
||||||
for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) {
|
|
||||||
omp_set_num_threads(i_nthreads);
|
|
||||||
// return initial state of global rng engine
|
|
||||||
common::GlobalRandom() = initial_rnd;
|
|
||||||
this->unused_rows_ = unused_rows_cpy;
|
|
||||||
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
|
|
||||||
std::vector<size_t>& row_indices = *(this->row_set_collection_.Data());
|
|
||||||
ASSERT_EQ(row_indices_initial.size(), row_indices.size());
|
|
||||||
for (size_t i = 0; i < row_indices_initial.size(); ++i) {
|
|
||||||
ASSERT_EQ(row_indices_initial[i], row_indices[i]);
|
|
||||||
}
|
|
||||||
std::vector<size_t>& unused_row_indices = this->unused_rows_;
|
|
||||||
ASSERT_EQ(unused_row_indices_initial.size(), unused_row_indices.size());
|
|
||||||
for (size_t i = 0; i < unused_row_indices_initial.size(); ++i) {
|
|
||||||
ASSERT_EQ(unused_row_indices_initial[i], unused_row_indices[i]);
|
|
||||||
}
|
|
||||||
check_each_row_occurs_in_one_of_arrays(row_indices, unused_row_indices,
|
|
||||||
p_fmat->Info().num_row_);
|
|
||||||
}
|
|
||||||
omp_set_num_threads(nthreads);
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestApplySplit(const RegTree& tree) {
|
|
||||||
std::vector<GradientPair> row_gpairs =
|
|
||||||
{ {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
|
|
||||||
{0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
|
|
||||||
int32_t constexpr kMaxBins = 4;
|
|
||||||
|
|
||||||
// try out different sparsity to get different number of missing values
|
|
||||||
for (double sparsity : {0.0, 0.1, 0.2}) {
|
|
||||||
// kNRows samples with kNCols features
|
|
||||||
auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
|
|
||||||
|
|
||||||
float sparse_th = 0.0;
|
|
||||||
GHistIndexMatrix gmat{dmat.get(), kMaxBins, sparse_th, false, common::OmpGetNumThreads(0)};
|
|
||||||
ColumnMatrix cm;
|
|
||||||
|
|
||||||
// treat everything as dense, as this is what we intend to test here
|
|
||||||
cm.Init(gmat, sparse_th, common::OmpGetNumThreads(0));
|
|
||||||
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
|
|
||||||
const size_t num_row = dmat->Info().num_row_;
|
|
||||||
// split by feature 0
|
|
||||||
const size_t bin_id_min = gmat.cut.Ptrs()[0];
|
|
||||||
const size_t bin_id_max = gmat.cut.Ptrs()[1];
|
|
||||||
|
|
||||||
// attempt to split at different bins
|
|
||||||
for (size_t split = 0; split < 4; split++) {
|
|
||||||
size_t left_cnt = 0, right_cnt = 0;
|
|
||||||
|
|
||||||
// manually compute how many samples go left or right
|
|
||||||
for (size_t rid = 0; rid < num_row; ++rid) {
|
|
||||||
for (size_t offset = gmat.row_ptr[rid]; offset < gmat.row_ptr[rid + 1]; ++offset) {
|
|
||||||
const size_t bin_id = gmat.index[offset];
|
|
||||||
if (bin_id >= bin_id_min && bin_id < bin_id_max) {
|
|
||||||
if (bin_id <= split) {
|
|
||||||
left_cnt++;
|
|
||||||
} else {
|
|
||||||
right_cnt++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if any were missing due to sparsity, we add them to the left or to the right
|
|
||||||
size_t missing = kNRows - left_cnt - right_cnt;
|
|
||||||
if (tree[0].DefaultLeft()) {
|
|
||||||
left_cnt += missing;
|
|
||||||
} else {
|
|
||||||
right_cnt += missing;
|
|
||||||
}
|
|
||||||
|
|
||||||
// have one node with kNRows (=8 at the moment) rows, just one task
|
|
||||||
RealImpl::partition_builder_.Init(1, 1, [&](size_t node_in_set) {
|
|
||||||
return 1;
|
|
||||||
});
|
|
||||||
const size_t task_id = RealImpl::partition_builder_.GetTaskIdx(0, 0);
|
|
||||||
RealImpl::partition_builder_.AllocateForTask(task_id);
|
|
||||||
if (cm.AnyMissing()) {
|
|
||||||
RealImpl::partition_builder_.template Partition<uint8_t, true>(0, 0, common::Range1d(0, kNRows),
|
|
||||||
split, cm, tree, this->row_set_collection_[0].begin);
|
|
||||||
} else {
|
|
||||||
RealImpl::partition_builder_.template Partition<uint8_t, false>(0, 0, common::Range1d(0, kNRows),
|
|
||||||
split, cm, tree, this->row_set_collection_[0].begin);
|
|
||||||
}
|
|
||||||
RealImpl::partition_builder_.CalculateRowOffsets();
|
|
||||||
ASSERT_EQ(RealImpl::partition_builder_.GetNLeftElems(0), left_cnt);
|
|
||||||
ASSERT_EQ(RealImpl::partition_builder_.GetNRightElems(0), right_cnt);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
int static constexpr kNRows = 8, kNCols = 16;
|
int static constexpr kNRows = 8, kNCols = 16;
|
||||||
@ -262,33 +139,6 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
float_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
|
float_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestInitDataSampling() {
|
|
||||||
int32_t constexpr kMaxBins = 4;
|
|
||||||
GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
|
|
||||||
|
|
||||||
RegTree tree = RegTree();
|
|
||||||
tree.param.UpdateAllowUnknown(cfg_);
|
|
||||||
|
|
||||||
std::vector<GradientPair> gpair =
|
|
||||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
|
||||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
|
||||||
if (double_builder_) {
|
|
||||||
double_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
|
|
||||||
} else {
|
|
||||||
float_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestApplySplit() {
|
|
||||||
RegTree tree = RegTree();
|
|
||||||
tree.param.UpdateAllowUnknown(cfg_);
|
|
||||||
if (double_builder_) {
|
|
||||||
double_builder_->TestApplySplit(tree);
|
|
||||||
} else {
|
|
||||||
float_builder_->TestApplySplit(tree);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST(QuantileHist, InitData) {
|
TEST(QuantileHist, InitData) {
|
||||||
@ -301,30 +151,62 @@ TEST(QuantileHist, InitData) {
|
|||||||
maker_float.TestInitData();
|
maker_float.TestInitData();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(QuantileHist, InitDataSampling) {
|
TEST(QuantileHist, Partitioner) {
|
||||||
const float subsample = 0.5;
|
size_t n_samples = 1024, n_features = 1, base_rowid = 0;
|
||||||
std::vector<std::pair<std::string, std::string>> cfg
|
GenericParameter ctx;
|
||||||
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
|
ctx.InitAllowUnknown(Args{});
|
||||||
{"subsample", std::to_string(subsample)}};
|
|
||||||
QuantileHistMock maker(cfg);
|
|
||||||
maker.TestInitDataSampling();
|
|
||||||
const bool single_precision_histogram = true;
|
|
||||||
QuantileHistMock maker_float(cfg, single_precision_histogram);
|
|
||||||
maker_float.TestInitDataSampling();
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(QuantileHist, ApplySplit) {
|
HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
|
||||||
std::vector<std::pair<std::string, std::string>> cfg
|
ASSERT_EQ(partitioner.base_rowid, base_rowid);
|
||||||
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
|
ASSERT_EQ(partitioner.Size(), 1);
|
||||||
{"split_evaluator", "elastic_net"},
|
ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);
|
||||||
{"reg_lambda", "0"}, {"reg_alpha", "0"}, {"max_delta_step", "0"},
|
|
||||||
{"min_child_weight", "0"}};
|
|
||||||
QuantileHistMock maker(cfg);
|
|
||||||
maker.TestApplySplit();
|
|
||||||
const bool single_precision_histogram = true;
|
|
||||||
QuantileHistMock maker_float(cfg, single_precision_histogram);
|
|
||||||
maker_float.TestApplySplit();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
|
||||||
|
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
|
||||||
|
|
||||||
|
auto grad = GenerateRandomGradients(n_samples);
|
||||||
|
std::vector<float> hess(grad.Size());
|
||||||
|
std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(),
|
||||||
|
[](auto gpair) { return gpair.GetHess(); });
|
||||||
|
|
||||||
|
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({64, 0.5})) {
|
||||||
|
bst_feature_t const split_ind = 0;
|
||||||
|
common::ColumnMatrix column_indices;
|
||||||
|
column_indices.Init(page, 0.5, ctx.Threads());
|
||||||
|
{
|
||||||
|
auto min_value = page.cut.MinValues()[split_ind];
|
||||||
|
RegTree tree;
|
||||||
|
HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
|
||||||
|
GetSplit(&tree, min_value, &candidates);
|
||||||
|
partitioner.UpdatePosition<false, true>(&ctx, page, column_indices, candidates, &tree);
|
||||||
|
ASSERT_EQ(partitioner.Size(), 3);
|
||||||
|
ASSERT_EQ(partitioner[1].Size(), 0);
|
||||||
|
ASSERT_EQ(partitioner[2].Size(), n_samples);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()};
|
||||||
|
auto ptr = page.cut.Ptrs()[split_ind + 1];
|
||||||
|
float split_value = page.cut.Values().at(ptr / 2);
|
||||||
|
RegTree tree;
|
||||||
|
GetSplit(&tree, split_value, &candidates);
|
||||||
|
auto left_nidx = tree[RegTree::kRoot].LeftChild();
|
||||||
|
partitioner.UpdatePosition<false, true>(&ctx, page, column_indices, candidates, &tree);
|
||||||
|
|
||||||
|
auto elem = partitioner[left_nidx];
|
||||||
|
ASSERT_LT(elem.Size(), n_samples);
|
||||||
|
ASSERT_GT(elem.Size(), 1);
|
||||||
|
for (auto it = elem.begin; it != elem.end; ++it) {
|
||||||
|
auto value = page.cut.Values().at(page.index[*it]);
|
||||||
|
ASSERT_LE(value, split_value);
|
||||||
|
}
|
||||||
|
auto right_nidx = tree[RegTree::kRoot].RightChild();
|
||||||
|
elem = partitioner[right_nidx];
|
||||||
|
for (auto it = elem.begin; it != elem.end; ++it) {
|
||||||
|
auto value = page.cut.Values().at(page.index[*it]);
|
||||||
|
ASSERT_GT(value, split_value) << *it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -245,3 +245,4 @@ class TestTreeMethod:
|
|||||||
@pytest.mark.skipif(**tm.no_pandas())
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
def test_categorical(self, rows, cols, rounds, cats):
|
def test_categorical(self, rows, cols, rounds, cats):
|
||||||
self.run_categorical_basic(rows, cols, rounds, cats, "approx")
|
self.run_categorical_basic(rows, cols, rounds, cats, "approx")
|
||||||
|
self.run_categorical_basic(rows, cols, rounds, cats, "hist")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user