Expand categorical node. (#6028)

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2020-08-25 18:53:57 +08:00
committed by GitHub
parent 9a4e8b1d81
commit 20c95be625
12 changed files with 340 additions and 103 deletions

View File

@@ -16,6 +16,7 @@
#if defined(__CUDACC__)
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include "device_helpers.cuh"
#endif // defined(__CUDACC__)
#include "xgboost/span.h"
@@ -54,23 +55,24 @@ __forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* addr
*
* \tparam Direction Whether the bits start from left or from right.
*/
template <typename VT, typename Direction>
template <typename VT, typename Direction, bool IsConst = false>
struct BitFieldContainer {
using value_type = VT; // NOLINT
using value_type = std::conditional_t<IsConst, VT const, VT>; // NOLINT
using pointer = value_type*; // NOLINT
static value_type constexpr kValueSize = sizeof(value_type) * 8;
static value_type constexpr kOne = 1; // force correct type.
struct Pos {
value_type int_pos {0};
value_type bit_pos {0};
std::remove_const_t<value_type> int_pos {0};
std::remove_const_t<value_type> bit_pos {0};
};
private:
common::Span<value_type> bits_;
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");
public:
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) {
Pos pos_v;
if (pos == 0) {
@@ -92,7 +94,7 @@ struct BitFieldContainer {
/*\brief Compute the size of needed memory allocation. The returned value is in terms
* of number of elements with `BitFieldContainer::value_type'.
*/
static size_t ComputeStorageSize(size_t size) {
XGBOOST_DEVICE static size_t ComputeStorageSize(size_t size) {
return common::DivRoundUp(size, kValueSize);
}
#if defined(__CUDA_ARCH__)
@@ -134,19 +136,19 @@ struct BitFieldContainer {
#endif // defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__)
__device__ void Set(value_type pos) {
__device__ auto Set(value_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos;
static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
AtomicOr(reinterpret_cast<BitFieldAtomicType*>(&value), set_bit);
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
atomicOr(reinterpret_cast<Type *>(&value), set_bit);
}
__device__ void Clear(value_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos);
static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
AtomicAnd(reinterpret_cast<BitFieldAtomicType*>(&value), clear_bit);
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
atomicAnd(reinterpret_cast<Type *>(&value), clear_bit);
}
#else
void Set(value_type pos) {
@@ -165,6 +167,7 @@ struct BitFieldContainer {
XGBOOST_DEVICE bool Check(Pos pos_v) const {
pos_v = Direction::Shift(pos_v);
SPAN_LT(pos_v.int_pos, bits_.size());
value_type const value = bits_[pos_v.int_pos];
value_type const test_bit = kOne << pos_v.bit_pos;
value_type result = test_bit & value;
@@ -179,10 +182,11 @@ struct BitFieldContainer {
XGBOOST_DEVICE pointer Data() const { return bits_.data(); }
friend std::ostream& operator<<(std::ostream& os, BitFieldContainer<VT, Direction> field) {
inline friend std::ostream &
operator<<(std::ostream &os, BitFieldContainer<VT, Direction, IsConst> field) {
os << "Bits " << "storage size: " << field.bits_.size() << "\n";
for (typename common::Span<value_type>::index_type i = 0; i < field.bits_.size(); ++i) {
std::bitset<BitFieldContainer<VT, Direction>::kValueSize> bset(field.bits_[i]);
std::bitset<BitFieldContainer<VT, Direction, IsConst>::kValueSize> bset(field.bits_[i]);
os << bset << "\n";
}
return os;
@@ -190,9 +194,9 @@ struct BitFieldContainer {
};
// Bits start from left most bits (most significant bit).
template <typename VT>
struct LBitsPolicy : public BitFieldContainer<VT, LBitsPolicy<VT>> {
using Container = BitFieldContainer<VT, LBitsPolicy<VT>>;
template <typename VT, bool IsConst = false>
struct LBitsPolicy : public BitFieldContainer<VT, LBitsPolicy<VT, IsConst>, IsConst> {
using Container = BitFieldContainer<VT, LBitsPolicy<VT, IsConst>, IsConst>;
using Pos = typename Container::Pos;
using value_type = typename Container::value_type; // NOLINT
@@ -215,38 +219,13 @@ struct RBitsPolicy : public BitFieldContainer<VT, RBitsPolicy<VT>> {
}
};
// Format: <Direction>BitField<size of underlying type in bits>, underlying type must be unsigned.
// Format: <Const><Direction>BitField<size of underlying type in bits>, underlying type
// must be unsigned.
using LBitField64 = BitFieldContainer<uint64_t, LBitsPolicy<uint64_t>>;
using RBitField8 = BitFieldContainer<uint8_t, RBitsPolicy<unsigned char>>;
#if defined(__CUDACC__)
template <typename V, typename D>
inline void PrintDeviceBits(std::string name, BitFieldContainer<V, D> field) {
std::cout << "Bits: " << name << std::endl;
std::vector<typename BitFieldContainer<V, D>::value_type> h_field_bits(field.bits_.size());
thrust::copy(thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(field.bits_.data()),
thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(
field.bits_.data() + field.bits_.size()),
h_field_bits.data());
BitFieldContainer<V, D> h_field;
h_field.bits_ = {h_field_bits.data(), h_field_bits.data() + h_field_bits.size()};
std::cout << h_field;
}
inline void PrintDeviceStorage(std::string name, common::Span<int32_t> list) {
std::cout << name << std::endl;
std::vector<int32_t> h_list(list.size());
thrust::copy(thrust::device_ptr<int32_t>(list.data()),
thrust::device_ptr<int32_t>(list.data() + list.size()),
h_list.data());
for (auto v : h_list) {
std::cout << v << ", ";
}
std::cout << std::endl;
}
#endif // defined(__CUDACC__)
using LBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t>>;
using CLBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t, true>, true>;
} // namespace xgboost
#endif // XGBOOST_COMMON_BITFIELD_H_

50
src/common/categorical.h Normal file
View File

@@ -0,0 +1,50 @@
/*!
* Copyright 2020 by XGBoost Contributors
* \file categorical.h
*/
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
#define XGBOOST_COMMON_CATEGORICAL_H_
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/span.h"
#include "xgboost/parameter.h"
#include "bitfield.h"
namespace xgboost {
namespace common {
// Cast the categorical type.
template <typename T>
XGBOOST_DEVICE bst_cat_t AsCat(T const& v) {
return static_cast<bst_cat_t>(v);
}
/* \brief Whether is fidx a categorical feature.
*
* \param ft Feature type for all features.
* \param fidx Feature index.
* \return Whether feature pointed by fidx is categorical feature.
*/
inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx) {
return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
}
/* \brief Whether should it traverse to left branch of a tree.
*
* For one hot split, go to left if it's NOT the matching category.
*/
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) {
auto pos = CLBitField32::ToBitPos(cat);
if (pos.int_pos >= cats.size()) {
return true;
}
CLBitField32 const s_cats(cats);
return !s_cats.Check(cat);
}
using CatBitField = LBitField32;
using KCatBitField = CLBitField32;
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_CATEGORICAL_H_

View File

@@ -275,6 +275,9 @@ Json& JsonNumber::operator[](int ind) {
bool JsonNumber::operator==(Value const& rhs) const {
if (!IsA<JsonNumber>(&rhs)) { return false; }
if (std::isinf(number_)) {
return std::isinf(Cast<JsonNumber const>(&rhs)->GetNumber());
}
return std::abs(number_ - Cast<JsonNumber const>(&rhs)->GetNumber()) < kRtEps;
}

View File

@@ -199,8 +199,10 @@ void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<Feat
types->emplace_back(FeatureType::kNumerical);
} else if (elem == "q") {
types->emplace_back(FeatureType::kNumerical);
} else if (elem == "categorical") {
types->emplace_back(FeatureType::kCategorical);
} else {
LOG(FATAL) << "All feature_types must be {int, float, i, q}";
LOG(FATAL) << "All feature_types must be one of {int, float, i, q, categorical}.";
}
}
}

View File

@@ -18,6 +18,7 @@
#include "param.h"
#include "../common/common.h"
#include "../common/categorical.h"
namespace xgboost {
// register tree parameter
@@ -662,6 +663,53 @@ bst_node_t RegTree::GetNumSplitNodes() const {
return splits;
}
void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
bool default_left, bst_float base_weight,
bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change,
float sum_hess, float left_sum, float right_sum,
bst_node_t leaf_right_child) {
int pleft = this->AllocNode();
int pright = this->AllocNode();
auto &node = nodes_[nid];
CHECK(node.IsLeaf());
node.SetLeftChild(pleft);
node.SetRightChild(pright);
nodes_[node.LeftChild()].SetParent(nid, true);
nodes_[node.RightChild()].SetParent(nid, false);
node.SetSplit(split_index, split_value, default_left);
nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
this->Stat(nid) = {loss_change, sum_hess, base_weight};
this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight};
this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
this->split_types_.at(nid) = FeatureType::kNumerical;
}
void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
common::Span<uint32_t> split_cat, bool default_left,
bst_float base_weight,
bst_float left_leaf_weight,
bst_float right_leaf_weight,
bst_float loss_change, float sum_hess,
float left_sum, float right_sum) {
this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(),
default_left, base_weight,
left_leaf_weight, right_leaf_weight, loss_change, sum_hess,
left_sum, right_sum);
size_t orig_size = split_categories_.size();
this->split_categories_.resize(orig_size + split_cat.size());
std::copy(split_cat.data(), split_cat.data() + split_cat.size(),
split_categories_.begin() + orig_size);
this->split_types_.at(nid) = FeatureType::kCategorical;
this->split_categories_segments_.at(nid).beg = orig_size;
this->split_categories_segments_.at(nid).size = split_cat.size();
}
void RegTree::Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
if (!DMLC_IO_NO_ENDIAN_SWAP) {
@@ -751,11 +799,24 @@ void RegTree::LoadModel(Json const& in) {
auto const& default_left = get<Array const>(in["default_left"]);
CHECK_EQ(default_left.size(), n_nodes);
bool has_cat = get<Object const>(in).find("split_type") != get<Object const>(in).cend();
std::vector<Json> split_type;
std::vector<Json> categories;
if (has_cat) {
split_type = get<Array const>(in["split_type"]);
categories = get<Array const>(in["categories"]);
}
stats_.clear();
nodes_.clear();
stats_.resize(n_nodes);
nodes_.resize(n_nodes);
split_types_.resize(n_nodes);
split_categories_segments_.resize(n_nodes);
CHECK_EQ(n_nodes, split_categories_segments_.size());
for (int32_t i = 0; i < n_nodes; ++i) {
auto& s = stats_[i];
s.loss_chg = get<Number const>(loss_changes[i]);
@@ -771,6 +832,31 @@ void RegTree::LoadModel(Json const& in) {
float cond { get<Number const>(conds[i]) };
bool dft_left { get<Boolean const>(default_left[i]) };
n = Node{left, right, parent, ind, cond, dft_left};
if (has_cat) {
split_types_[i] =
static_cast<FeatureType>(get<Integer const>(split_type[i]));
auto const& j_categories = get<Array const>(categories[i]);
bst_cat_t max_cat { std::numeric_limits<bst_cat_t>::min() };
for (auto const& j_cat : j_categories) {
auto cat = common::AsCat(get<Integer const>(j_cat));
max_cat = std::max(max_cat, cat);
}
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
? 0
: common::KCatBitField::ComputeStorageSize(max_cat);
std::vector<uint32_t> cat_bits_storage(size);
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
for (auto const& j_cat : j_categories) {
cat_bits.Set(common::AsCat(get<Integer const>(j_cat)));
}
auto begin = split_categories_.size();
split_categories_.resize(begin + cat_bits_storage.size());
std::copy(cat_bits_storage.begin(), cat_bits_storage.end(),
split_categories_.begin() + begin);
split_categories_segments_[i].beg = begin;
split_categories_segments_[i].size = cat_bits_storage.size();
}
}
deleted_nodes_.clear();
@@ -811,8 +897,11 @@ void RegTree::SaveModel(Json* p_out) const {
std::vector<Json> indices(n_nodes);
std::vector<Json> conds(n_nodes);
std::vector<Json> default_left(n_nodes);
std::vector<Json> split_type(n_nodes);
for (int32_t i = 0; i < n_nodes; ++i) {
std::vector<Json> categories(n_nodes);
for (bst_node_t i = 0; i < n_nodes; ++i) {
auto const& s = stats_[i];
loss_changes[i] = s.loss_chg;
sum_hessian[i] = s.sum_hess;
@@ -826,6 +915,24 @@ void RegTree::SaveModel(Json* p_out) const {
indices[i] = static_cast<I>(n.SplitIndex());
conds[i] = n.SplitCond();
default_left[i] = n.DefaultLeft();
std::vector<Json> categories_temp;
// This condition is only for being compatibale with older version of XGBoost model
// that doesn't have categorical data support.
if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);
split_type[i] = static_cast<I>(this->NodeSplitType(i));
auto beg = this->GetSplitCategoriesPtr().at(i).beg;
auto size = this->GetSplitCategoriesPtr().at(i).size;
auto node_categories = this->GetSplitCategories().subspan(beg, size);
common::KCatBitField const cat_bits(node_categories);
for (size_t i = 0; i < cat_bits.Size(); ++i) {
if (cat_bits.Check(i)) {
categories_temp.emplace_back(static_cast<Integer::Int>(i));
}
}
}
categories[i] = Array(categories_temp);
}
out["loss_changes"] = std::move(loss_changes);
@@ -839,6 +946,12 @@ void RegTree::SaveModel(Json* p_out) const {
out["split_indices"] = std::move(indices);
out["split_conditions"] = std::move(conds);
out["default_left"] = std::move(default_left);
out["categories"] = categories;
if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
out["split_type"] = std::move(split_type);
}
}
void RegTree::FillNodeMeanValues() {