Expand categorical node. (#6028)
Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -109,7 +109,8 @@ using bst_int = int32_t; // NOLINT
|
||||
using bst_ulong = uint64_t; // NOLINT
|
||||
/*! \brief float type, used for storing statistics */
|
||||
using bst_float = float; // NOLINT
|
||||
|
||||
/*! \brief Categorical value type. */
|
||||
using bst_cat_t = int32_t; // NOLINT
|
||||
/*! \brief Type for data column (feature) index. */
|
||||
using bst_feature_t = uint32_t; // NOLINT
|
||||
/*! \brief Type for data row index.
|
||||
|
||||
@@ -35,7 +35,8 @@ enum class DataType : uint8_t {
|
||||
};
|
||||
|
||||
enum class FeatureType : uint8_t {
|
||||
kNumerical
|
||||
kNumerical,
|
||||
kCategorical
|
||||
};
|
||||
|
||||
/*!
|
||||
@@ -309,12 +310,6 @@ class SparsePage {
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Push row block into the page.
|
||||
* \param batch the row batch.
|
||||
*/
|
||||
void Push(const dmlc::RowBlock<uint32_t>& batch);
|
||||
|
||||
/**
|
||||
* \brief Pushes external data batch onto this page
|
||||
*
|
||||
|
||||
@@ -101,6 +101,18 @@ namespace common {
|
||||
} while (0);
|
||||
#endif // __CUDA_ARCH__
|
||||
|
||||
/*!
 * \brief Check that `lhs < rhs`; used for span index bounds checking.
 *
 * In device code (when __CUDA_ARCH__ is defined) a failed check prints both
 * operands and then executes a trap instruction.  In host code the check is
 * forwarded to SPAN_CHECK.
 *
 * \param lhs Left operand of the comparison (e.g. the index).
 * \param rhs Right operand of the comparison (e.g. the size).
 */
#if defined(__CUDA_ARCH__)
// The operands are widened to unsigned long long and printed with "%llu":
// size_t is not `unsigned long` on every platform (64-bit Windows has a 32-bit
// long), so passing a size_t to "%lu" is a format/argument mismatch there.
// The do { } while (0) wrapper turns the macro into a single statement so it
// can be used safely inside an unbraced if/else.
#define SPAN_LT(lhs, rhs)                                  \
  do {                                                     \
    if (!((lhs) < (rhs))) {                                \
      printf("%llu < %llu failed\n",                       \
             static_cast<unsigned long long>(lhs),         \
             static_cast<unsigned long long>(rhs));        \
      asm("trap;");                                        \
    }                                                      \
  } while (0)
#else
#define SPAN_LT(lhs, rhs) SPAN_CHECK((lhs) < (rhs))
#endif  // defined(__CUDA_ARCH__)
|
||||
|
||||
namespace detail {
|
||||
/*!
|
||||
* By default, XGBoost uses uint32_t for indexing data. int64_t covers all
|
||||
@@ -515,7 +527,7 @@ class Span {
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE reference operator[](index_type _idx) const {
|
||||
SPAN_CHECK(_idx < size());
|
||||
SPAN_LT(_idx, size());
|
||||
return data()[_idx];
|
||||
}
|
||||
|
||||
@@ -575,7 +587,6 @@ class Span {
|
||||
detail::ExtentValue<Extent, Offset, Count>::value> {
|
||||
SPAN_CHECK((Count == dynamic_extent) ?
|
||||
(Offset <= size()) : (Offset + Count <= size()));
|
||||
|
||||
return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
|
||||
}
|
||||
|
||||
|
||||
@@ -318,6 +318,8 @@ class RegTree : public Model {
|
||||
param.num_deleted = 0;
|
||||
nodes_.resize(param.num_nodes);
|
||||
stats_.resize(param.num_nodes);
|
||||
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(param.num_nodes);
|
||||
for (int i = 0; i < param.num_nodes; i ++) {
|
||||
nodes_[i].SetLeaf(0.0f);
|
||||
nodes_[i].SetParent(kInvalidNodeId);
|
||||
@@ -412,30 +414,33 @@ class RegTree : public Model {
|
||||
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
|
||||
* some updaters use the right child index of leaf as a marker
|
||||
*/
|
||||
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
|
||||
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
|
||||
bool default_left, bst_float base_weight,
|
||||
bst_float left_leaf_weight, bst_float right_leaf_weight,
|
||||
bst_float loss_change, float sum_hess, float left_sum,
|
||||
float right_sum,
|
||||
bst_node_t leaf_right_child = kInvalidNodeId) {
|
||||
int pleft = this->AllocNode();
|
||||
int pright = this->AllocNode();
|
||||
auto &node = nodes_[nid];
|
||||
CHECK(node.IsLeaf());
|
||||
node.SetLeftChild(pleft);
|
||||
node.SetRightChild(pright);
|
||||
nodes_[node.LeftChild()].SetParent(nid, true);
|
||||
nodes_[node.RightChild()].SetParent(nid, false);
|
||||
node.SetSplit(split_index, split_value,
|
||||
default_left);
|
||||
bst_node_t leaf_right_child = kInvalidNodeId);
|
||||
|
||||
nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
|
||||
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
|
||||
|
||||
this->Stat(nid) = {loss_change, sum_hess, base_weight};
|
||||
this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight};
|
||||
this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
|
||||
}
|
||||
/**
|
||||
* \brief Expands a leaf node with categories
|
||||
*
|
||||
* \param nid The node index to expand.
|
||||
* \param split_index Feature index of the split.
|
||||
* \param split_cat The bitset containing categories
|
||||
* \param default_left True to default left.
|
||||
* \param base_weight The base weight, before learning rate.
|
||||
* \param left_leaf_weight The left leaf weight for prediction, modified by learning rate.
|
||||
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
|
||||
* \param loss_change The loss change.
|
||||
* \param sum_hess The sum hess.
|
||||
* \param left_sum The sum hess of left leaf.
|
||||
* \param right_sum The sum hess of right leaf.
|
||||
*/
|
||||
void ExpandCategorical(bst_node_t nid, unsigned split_index,
|
||||
common::Span<uint32_t> split_cat, bool default_left,
|
||||
bst_float base_weight, bst_float left_leaf_weight,
|
||||
bst_float right_leaf_weight, bst_float loss_change,
|
||||
float sum_hess, float left_sum, float right_sum);
|
||||
|
||||
/*!
|
||||
* \brief get current depth
|
||||
@@ -588,6 +593,28 @@ class RegTree : public Model {
|
||||
* \brief calculate the mean value for each node, required for feature contributions
|
||||
*/
|
||||
void FillNodeMeanValues();
|
||||
/*!
|
||||
* \brief Get split type for a node.
|
||||
* \param nidx Index of node.
|
||||
* \return The type of this split. For a leaf node it is always kNumerical.
|
||||
*/
|
||||
FeatureType NodeSplitType(bst_node_t nidx) const {
|
||||
// Bounds-checked lookup: std::vector::at throws std::out_of_range for an invalid node index.
return split_types_.at(nidx);
|
||||
}
|
||||
/*!
|
||||
* \brief Get split types for all nodes.
|
||||
*/
|
||||
// No copy is made: the returned reference aliases the tree's internal split_types_ vector.
std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; }
|
||||
/*!
 * \brief Get a read-only view over split_categories_, the flat storage that
 *        holds the category bitsets of all nodes.  The slice belonging to a
 *        particular node is described by the segments returned from
 *        GetSplitCategoriesPtr().
 */
common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; }
|
||||
/*!
 * \brief Get the per-node segments (begin offset and size) that index into the
 *        split categories storage.  Returned by const reference, no copy is made.
 */
auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
|
||||
|
||||
// The fields of split_categories_segments_[i] are set such that
|
||||
// the range split_categories_[beg:(beg+size)] stores the bitset for
|
||||
// the matching categories for the i-th node.
|
||||
struct Segment {
|
||||
size_t beg {0};  // Offset of the first element of this node's range in split_categories_.
|
||||
size_t size {0};  // Number of elements in this node's range; defaults to 0.
|
||||
};
|
||||
|
||||
private:
|
||||
// vector of nodes
|
||||
@@ -597,9 +624,16 @@ class RegTree : public Model {
|
||||
// stats of nodes
|
||||
std::vector<RTreeNodeStat> stats_;
|
||||
std::vector<bst_float> node_mean_values_;
|
||||
std::vector<FeatureType> split_types_;
|
||||
|
||||
// Categories for each internal node.
|
||||
std::vector<uint32_t> split_categories_;
|
||||
// Segment (begin offset and size) into split_categories_ for each node.
|
||||
std::vector<Segment> split_categories_segments_;
|
||||
|
||||
// allocate a new node,
|
||||
// !!!!!! NOTE: may cause BUG here, nodes.resize
|
||||
int AllocNode() {
|
||||
bst_node_t AllocNode() {
|
||||
if (param.num_deleted != 0) {
|
||||
int nid = deleted_nodes_.back();
|
||||
deleted_nodes_.pop_back();
|
||||
@@ -612,6 +646,8 @@ class RegTree : public Model {
|
||||
<< "number of nodes in the tree exceed 2^31";
|
||||
nodes_.resize(param.num_nodes);
|
||||
stats_.resize(param.num_nodes);
|
||||
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(param.num_nodes);
|
||||
return nd;
|
||||
}
|
||||
// delete a tree node, keep the parent field to allow trace back
|
||||
|
||||
Reference in New Issue
Block a user