CPU evaluation for cat data. (#7393)

* Implementation for one-hot based categorical splits.
* Implementation for partition based categorical splits. (LightGBM)

parent 6ede12412c · commit d7d1b6e3a6
src/common/categorical.h
@@ -5,11 +5,12 @@
 #ifndef XGBOOST_COMMON_CATEGORICAL_H_
 #define XGBOOST_COMMON_CATEGORICAL_H_
 
+#include "bitfield.h"
 #include "xgboost/base.h"
 #include "xgboost/data.h"
-#include "xgboost/span.h"
 #include "xgboost/parameter.h"
-#include "bitfield.h"
+#include "xgboost/span.h"
+#include "xgboost/task.h"
 
 namespace xgboost {
 namespace common {
@@ -47,6 +48,15 @@ inline void InvalidCategory() {
       "should be non-negative.";
 }
 
+/*!
+ * \brief Whether we should use one-hot encoding for categorical data.
+ */
+inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) {
+  bool use_one_hot = n_cats < max_cat_to_onehot ||
+                     (task.task != ObjInfo::kRegression && task.task != ObjInfo::kBinary);
+  return use_one_hot;
+}
+
 struct IsCatOp {
   XGBOOST_DEVICE bool operator()(FeatureType ft) {
     return ft == FeatureType::kCategorical;
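Note: the decision rule above reads better outside diff formatting. A minimal standalone sketch of the same logic, with a plain enum standing in for `ObjInfo` (names here are illustrative, not the library API):

    #include <cstdint>

    enum class Task { kRegression, kBinary, kOther };

    // Few categories: one-hot (one category vs. the rest) stays cheap and exact.
    // Other task types also fall back to one-hot, because the partition scheme
    // relies on sorting categories by their gradient statistics.
    bool UseOneHotSketch(std::uint32_t n_cats, std::uint32_t max_cat_to_onehot, Task task) {
      return n_cats < max_cat_to_onehot ||
             (task != Task::kRegression && task != Task::kBinary);
    }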
@@ -6,13 +6,16 @@
 
 #include <algorithm>
 #include <memory>
+#include <numeric>
 #include <limits>
 #include <utility>
 #include <vector>
 
+#include "xgboost/task.h"
 #include "../param.h"
 #include "../constraints.h"
 #include "../split_evaluator.h"
+#include "../../common/categorical.h"
 #include "../../common/random.h"
 #include "../../common/hist_util.h"
 #include "../../data/gradient_index.h"
@@ -36,13 +39,13 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
   int32_t n_threads_ {0};
   FeatureInteractionConstraintHost interaction_constraints_;
   std::vector<NodeEntry> snode_;
+  ObjInfo task_;
 
   // if sum of statistics for non-missing values in the node
   // is equal to sum of statistics for all values:
   // then - there are no missing values
   // else - there are missing values
-  bool static SplitContainsMissingValues(const GradStats e,
-                                         const NodeEntry &snode) {
+  bool static SplitContainsMissingValues(const GradStats e, const NodeEntry &snode) {
     if (e.GetGrad() == snode.stats.GetGrad() &&
         e.GetHess() == snode.stats.GetHess()) {
       return false;
@@ -50,38 +53,40 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
       return true;
     }
   }
+  enum SplitType { kNum = 0, kOneHot = 1, kPart = 2 };
 
   // Enumerate/Scan the split values of specific feature
   // Returns the sum of gradients corresponding to the data points that contains
   // a non-missing value for the particular feature fid.
-  template <int d_step>
-  GradStats EnumerateSplit(
-      common::HistogramCuts const &cut, const common::GHistRow<GradientSumT> &hist,
-      const NodeEntry &snode, SplitEntry *p_best, bst_feature_t fidx,
-      bst_node_t nidx,
-      TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator) const {
+  template <int d_step, SplitType split_type>
+  GradStats EnumerateSplit(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
+                           const common::GHistRow<GradientSumT> &hist, bst_feature_t fidx,
+                           bst_node_t nidx,
+                           TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
+                           SplitEntry *p_best) const {
     static_assert(d_step == +1 || d_step == -1, "Invalid step.");
 
     // aliases
     const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
     const std::vector<bst_float> &cut_val = cut.Values();
+    auto const &parent = snode_[nidx];
+    int32_t n_bins{static_cast<int32_t>(cut_ptr.at(fidx + 1) - cut_ptr[fidx])};
+    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
 
     // statistics on both sides of split
-    GradStats c;
-    GradStats e;
+    GradStats left_sum;
+    GradStats right_sum;
     // best split so far
     SplitEntry best;
 
     // bin boundaries
-    CHECK_LE(cut_ptr[fidx],
-             static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-    CHECK_LE(cut_ptr[fidx + 1],
-             static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-    // imin: index (offset) of the minimum value for feature fid
-    // need this for backward enumeration
+    CHECK_LE(cut_ptr[fidx], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+    CHECK_LE(cut_ptr[fidx + 1], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+    // imin: index (offset) of the minimum value for feature fid; need this for backward
+    // enumeration
     const auto imin = static_cast<int32_t>(cut_ptr[fidx]);
-    // ibegin, iend: smallest/largest cut points for feature fid
-    // use int to allow for value -1
+    // ibegin, iend: smallest/largest cut points for feature fid; use int to allow for
+    // value -1
     int32_t ibegin, iend;
     if (d_step > 0) {
       ibegin = static_cast<int32_t>(cut_ptr[fidx]);
@@ -91,49 +96,118 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
       iend = static_cast<int32_t>(cut_ptr[fidx]) - 1;
     }
 
+    auto calc_bin_value = [&](auto i) {
+      switch (split_type) {
+        case kNum: {
+          left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
+          right_sum.SetSubstract(parent.stats, left_sum);
+          break;
+        }
+        case kOneHot: {
+          // not-chosen categories go to left
+          right_sum = GradStats{hist[i]};
+          left_sum.SetSubstract(parent.stats, right_sum);
+          break;
+        }
+        case kPart: {
+          auto j = d_step == 1 ? (i - ibegin) : (ibegin - i);
+          right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
+          left_sum.SetSubstract(parent.stats, right_sum);
+          break;
+        }
+        default: {
+          std::terminate();
+        }
+      }
+    };
+
+    int32_t best_thresh{-1};
     for (int32_t i = ibegin; i != iend; i += d_step) {
       // start working
       // try to find a split
-      e.Add(hist[i].GetGrad(), hist[i].GetHess());
-      if (e.GetHess() >= param_.min_child_weight) {
-        c.SetSubstract(snode.stats, e);
-        if (c.GetHess() >= param_.min_child_weight) {
-          bst_float loss_chg;
-          bst_float split_pt;
-          if (d_step > 0) {
-            // forward enumeration: split at right bound of each bin
-            loss_chg = static_cast<bst_float>(
-                evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{e},
-                                        GradStats{c}) -
-                snode.root_gain);
-            split_pt = cut_val[i];
-            best.Update(loss_chg, fidx, split_pt, d_step == -1, e, c);
-          } else {
-            // backward enumeration: split at left bound of each bin
-            loss_chg = static_cast<bst_float>(
-                evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{c},
-                                        GradStats{e}) -
-                snode.root_gain);
-            if (i == imin) {
-              // for leftmost bin, left bound is the smallest feature value
-              split_pt = cut.MinValues()[fidx];
-            } else {
-              split_pt = cut_val[i - 1];
-            }
-            best.Update(loss_chg, fidx, split_pt, d_step == -1, c, e);
-          }
-        }
-      }
-    }
-    p_best->Update(best);
-    return e;
+      calc_bin_value(i);
+      bool improved{false};
+      if (left_sum.GetHess() >= param_.min_child_weight &&
+          right_sum.GetHess() >= param_.min_child_weight) {
+        bst_float loss_chg;
+        bst_float split_pt;
+        if (d_step > 0) {
+          // forward enumeration: split at right bound of each bin
+          loss_chg =
+              static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum},
+                                                         GradStats{right_sum}) -
+                                 parent.root_gain);
+          split_pt = cut_val[i];
+          improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
+                                 left_sum, right_sum);
+        } else {
+          // backward enumeration: split at left bound of each bin
+          loss_chg =
+              static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{right_sum},
+                                                         GradStats{left_sum}) -
+                                 parent.root_gain);
+          switch (split_type) {
+            case kNum: {
+              if (i == imin) {
+                // for leftmost bin, left bound is the smallest feature value
+                split_pt = cut.MinValues()[fidx];
+              } else {
+                split_pt = cut_val[i - 1];
+              }
+              break;
+            }
+            case kOneHot: {
+              split_pt = cut_val[i];
+              break;
+            }
+            case kPart: {
+              split_pt = cut_val[i];
+              break;
+            }
+          }
+          improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
+                                 right_sum, left_sum);
+        }
+        if (improved) {
+          best_thresh = i;
+        }
+      }
+    }
+
+    if (split_type == kPart && best_thresh != -1) {
+      auto n = common::CatBitField::ComputeStorageSize(n_bins);
+      best.cat_bits.resize(n, 0);
+      common::CatBitField cat_bits{best.cat_bits};
+
+      if (d_step == 1) {
+        std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1),
+                      [&cat_bits](size_t c) { cat_bits.Set(c); });
+      } else {
+        std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh),
+                      [&cat_bits](size_t c) { cat_bits.Set(c); });
+      }
+    }
+    p_best->Update(best);
+
+    switch (split_type) {
+      case kNum:
+        // Normal, accumulated to left
+        return left_sum;
+      case kOneHot:
+        // Doesn't matter, not accumulating.
+        return {};
+      case kPart:
+        // Accumulated to right, because the chosen categories go to the right.
+        return right_sum;
+    }
+    return left_sum;
   }
 
  public:
   void EvaluateSplits(const common::HistCollection<GradientSumT> &hist,
-                      common::HistogramCuts const &cut, const RegTree &tree,
-                      std::vector<ExpandEntry>* p_entries) {
+                      common::HistogramCuts const &cut,
+                      common::Span<FeatureType const> feature_types,
+                      const RegTree &tree,
+                      std::vector<ExpandEntry> *p_entries) {
     auto& entries = *p_entries;
     // All nodes are on the same level, so we can store the shared ptr.
     std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(
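Note: `EnumerateSplit` now folds three split families into one scan, selected at compile time by `split_type`. As a reading aid, a hedged, self-contained sketch of the three accumulation modes over a plain histogram (all names local to this example, not the library API):

    #include <cstddef>
    #include <utility>
    #include <vector>

    struct Stats {
      double grad{0}, hess{0};
      void Add(Stats const &s) { grad += s.grad; hess += s.hess; }
      static Stats Sub(Stats const &a, Stats const &b) {
        return {a.grad - b.grad, a.hess - b.hess};
      }
    };

    enum class Mode { kNum, kOneHot, kPart };

    // Returns {left, right} sums when a forward scan reaches bin i.
    std::pair<Stats, Stats> Accumulate(Mode mode, std::vector<Stats> const &hist,
                                       std::vector<std::size_t> const &sorted_idx,
                                       Stats const &parent, Stats *acc, std::size_t i) {
      switch (mode) {
        case Mode::kNum:     // ordered scan: bins <= i go left
          acc->Add(hist[i]);
          return {*acc, Stats::Sub(parent, *acc)};
        case Mode::kOneHot:  // one category vs. rest: bin i alone goes right
          return {Stats::Sub(parent, hist[i]), hist[i]};
        case Mode::kPart:    // partition: chosen (sorted) categories go right
          acc->Add(hist[sorted_idx[i]]);
          return {Stats::Sub(parent, *acc), *acc};
      }
      return {*acc, Stats::Sub(parent, *acc)};
    }

The numerical mode keeps the old behavior; one-hot treats each bin as "this category vs. everything else"; the partition mode walks the bins in an order supplied by the caller (see the `stable_sort` in `EvaluateSplits` below).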
@@ -150,7 +224,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
         return features[nidx_in_set]->Size();
       }, grain_size);
 
-    std::vector<ExpandEntry> tloc_candidates(omp_get_max_threads() * entries.size());
+    std::vector<ExpandEntry> tloc_candidates(n_threads_ * entries.size());
     for (size_t i = 0; i < entries.size(); ++i) {
       for (decltype(n_threads_) j = 0; j < n_threads_; ++j) {
         tloc_candidates[i * n_threads_ + j] = entries[i];
@@ -167,12 +241,37 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
         auto features_set = features[nidx_in_set]->ConstHostSpan();
         for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
           auto fidx = features_set[fidx_in_set];
-          if (interaction_constraints_.Query(nidx, fidx)) {
-            auto grad_stats = EnumerateSplit<+1>(cut, histogram, snode_[nidx],
-                                                 best, fidx, nidx, evaluator);
-            if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
-              EnumerateSplit<-1>(cut, histogram, snode_[nidx], best, fidx, nidx,
-                                 evaluator);
+          bool is_cat = common::IsCat(feature_types, fidx);
+          if (!interaction_constraints_.Query(nidx, fidx)) {
+            continue;
+          }
+          if (is_cat) {
+            auto n_bins = cut.Ptrs().at(fidx + 1) - cut.Ptrs()[fidx];
+            if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) {
+              EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
+              EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
+            } else {
+              auto const &cut_ptr = cut.Ptrs();
+              std::vector<size_t> sorted_idx(n_bins);
+              std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
+              auto feat_hist = histogram.subspan(cut_ptr[fidx], n_bins);
+              std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
+                auto ret = evaluator.CalcWeightCat(param_, feat_hist[l]) <
+                           evaluator.CalcWeightCat(param_, feat_hist[r]);
+                static_assert(std::is_same<decltype(ret), bool>::value, "");
+                return ret;
+              });
+              auto grad_stats =
+                  EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
+              if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
+                EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
+              }
+            }
+          } else {
+            auto grad_stats =
+                EnumerateSplit<+1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
+            if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
+              EnumerateSplit<-1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
             }
           }
         }
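Note: why sorting bins by `CalcWeightCat` makes a single ordered scan sufficient rests on a classical observation, applied the same way in LightGBM. In hedged sketch form, with notation local to this note:

    % G_c, H_c: gradient/hessian sums of category c within the node.
    w_c = -\frac{G_c}{H_c + \lambda}
    % Scanning prefixes of the categories sorted by w_c maximizes the split gain
    \mathrm{Gain} = \frac{G_L^2}{H_L+\lambda} + \frac{G_R^2}{H_R+\lambda}
                  - \frac{(G_L+G_R)^2}{H_L+H_R+\lambda}

For convex losses under the usual second-order approximation, an optimal binary partition of the categories is attained at one of the prefixes of this ordering, so the `std::stable_sort` on per-category weights followed by the standard left-to-right scan in `EnumerateSplit<d_step, kPart>` finds it without trying all 2^k subsets.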
@@ -187,7 +286,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
     }
   }
   // Add splits to tree, handles all statistic
-  void ApplyTreeSplit(ExpandEntry candidate, RegTree *p_tree) {
+  void ApplyTreeSplit(ExpandEntry const& candidate, RegTree *p_tree) {
     auto evaluator = tree_evaluator_.GetEvaluator();
     RegTree &tree = *p_tree;
 
@@ -201,13 +300,31 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
     auto right_weight = evaluator.CalcWeight(
         candidate.nid, param_, GradStats{candidate.split.right_sum});
 
-    tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(),
-                    candidate.split.split_value, candidate.split.DefaultLeft(),
-                    base_weight, left_weight * param_.learning_rate,
-                    right_weight * param_.learning_rate,
-                    candidate.split.loss_chg, parent_sum.GetHess(),
-                    candidate.split.left_sum.GetHess(),
-                    candidate.split.right_sum.GetHess());
+    if (candidate.split.is_cat) {
+      std::vector<uint32_t> split_cats;
+      if (candidate.split.cat_bits.empty()) {
+        CHECK_LT(candidate.split.split_value, std::numeric_limits<bst_cat_t>::max())
+            << "Categorical feature value too large.";
+        auto cat = common::AsCat(candidate.split.split_value);
+        split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
+        LBitField32 cat_bits;
+        cat_bits = LBitField32(split_cats);
+        cat_bits.Set(cat);
+      } else {
+        split_cats = candidate.split.cat_bits;
+      }
+
+      tree.ExpandCategorical(
+          candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(),
+          base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
+          candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
+    } else {
+      tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
+                      candidate.split.DefaultLeft(), base_weight,
+                      left_weight * param_.learning_rate, right_weight * param_.learning_rate,
+                      candidate.split.loss_chg, parent_sum.GetHess(),
+                      candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
+    }
 
     // Set up child constraints
     auto left_child = tree[candidate.nid].LeftChild();
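Note: one-hot winners reach this point with an empty `cat_bits`, so the single chosen category is re-encoded from `split_value` into a one-bit bitfield before `ExpandCategorical`. A hedged standalone sketch of the 32-bit-word encoding these buffers use (illustrative, not the `LBitField32`/`CatBitField` API itself):

    #include <cstdint>
    #include <vector>

    // Record "category c goes right" as bit c in 32-bit words.
    std::vector<std::uint32_t> EncodeCats(std::vector<std::uint32_t> const &cats,
                                          std::uint32_t n_cats) {
      std::vector<std::uint32_t> bits((n_cats + 31) / 32, 0);  // storage size
      for (auto c : cats) {
        bits[c / 32] |= (1u << (c % 32));
      }
      return bits;
    }

    // Decision at prediction time: categories with a set bit go right.
    bool GoRight(std::vector<std::uint32_t> const &bits, std::uint32_t cat) {
      if (cat / 32 >= bits.size()) return false;  // unseen category: not chosen
      return (bits[cat / 32] >> (cat % 32)) & 1u;
    }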
@@ -249,14 +366,14 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
  public:
   // The column sampler must be constructed by caller since we need to preserve the rng
   // for the entire training session.
-  explicit HistEvaluator(TrainParam const &param, MetaInfo const &info,
-                         int32_t n_threads,
-                         std::shared_ptr<common::ColumnSampler> sampler,
+  explicit HistEvaluator(TrainParam const &param, MetaInfo const &info, int32_t n_threads,
+                         std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task,
                          bool skip_0_index = false)
-      : param_{param}, column_sampler_{std::move(sampler)},
-        tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_),
-                        GenericParameter::kCpuId},
-        n_threads_{n_threads} {
+      : param_{param},
+        column_sampler_{std::move(sampler)},
+        tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), GenericParameter::kCpuId},
+        n_threads_{n_threads},
+        task_{task} {
     interaction_constraints_.Configure(param, info.num_col_);
     column_sampler_->Init(info.num_col_, info.feature_weigths.HostVector(),
                           param_.colsample_bynode, param_.colsample_bylevel,
src/tree/param.h
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2019 by Contributors
+ * Copyright 2014-2021 by Contributors
  * \file param.h
  * \brief training parameters, statistics used to support tree construction.
  * \author Tianqi Chen
@@ -7,6 +7,7 @@
 #ifndef XGBOOST_TREE_PARAM_H_
 #define XGBOOST_TREE_PARAM_H_
 
+#include <algorithm>
 #include <cmath>
 #include <cstring>
 #include <limits>
@@ -15,6 +16,7 @@
 
 #include "xgboost/parameter.h"
 #include "xgboost/data.h"
+#include "../common/categorical.h"
 #include "../common/math.h"
 
 namespace xgboost {
@@ -36,6 +38,8 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
   enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 };
   int grow_policy;
 
+  uint32_t max_cat_to_onehot{1};
+
   //----- the rest parameters are less important ----
   // minimum amount of hessian(weight) allowed in a child
   float min_child_weight;
@@ -119,6 +123,10 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
         "Tree growing policy. 0: favor splitting at nodes closest to the node, "
         "i.e. grow depth-wise. 1: favor splitting at nodes with highest loss "
         "change. (cf. LightGBM)");
+    DMLC_DECLARE_FIELD(max_cat_to_onehot)
+        .set_default(4)
+        .set_lower_bound(1)
+        .describe("Maximum number of categories to use one-hot encoding based split.");
     DMLC_DECLARE_FIELD(min_child_weight)
         .set_lower_bound(0.0f)
         .set_default(1.0f)
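Note: because the field is declared through DMLC, it is settable like any other training parameter. A hedged usage sketch through the C API (a valid `BoosterHandle` and error handling are assumed, and whether the parameter takes effect depends on the tree method and categorical support enabled in the build):

    #include <xgboost/c_api.h>

    void ConfigureCatSplits(BoosterHandle booster) {
      XGBoosterSetParam(booster, "tree_method", "hist");
      // Features with fewer than 16 categories use one-hot splits; wider
      // features use the partition-based scheme.
      XGBoosterSetParam(booster, "max_cat_to_onehot", "16");
    }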
@@ -384,6 +392,8 @@ struct SplitEntryContainer {
   /*! \brief split index */
   bst_feature_t sindex{0};
   bst_float split_value{0.0f};
+  std::vector<uint32_t> cat_bits;
+  bool is_cat{false};
 
   GradientT left_sum;
   GradientT right_sum;
@@ -433,6 +443,8 @@ struct SplitEntryContainer {
       this->loss_chg = e.loss_chg;
       this->sindex = e.sindex;
       this->split_value = e.split_value;
+      this->is_cat = e.is_cat;
+      this->cat_bits = e.cat_bits;
       this->left_sum = e.left_sum;
       this->right_sum = e.right_sum;
       return true;
@@ -449,9 +461,8 @@ struct SplitEntryContainer {
    * \return whether the proposed split is better and can replace current split
    */
   bool Update(bst_float new_loss_chg, unsigned split_index,
-              bst_float new_split_value, bool default_left,
-              const GradientT &left_sum,
-              const GradientT &right_sum) {
+              bst_float new_split_value, bool default_left, bool is_cat,
+              const GradientT &left_sum, const GradientT &right_sum) {
     if (this->NeedReplace(new_loss_chg, split_index)) {
       this->loss_chg = new_loss_chg;
       if (default_left) {
@@ -459,6 +470,31 @@ struct SplitEntryContainer {
       }
       this->sindex = split_index;
       this->split_value = new_split_value;
+      this->is_cat = is_cat;
+      this->left_sum = left_sum;
+      this->right_sum = right_sum;
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  /*!
+   * \brief Update with partition based categorical split.
+   *
+   * \return Whether the proposed split is better and can replace current split.
+   */
+  bool Update(float new_loss_chg, bst_feature_t split_index, common::KCatBitField cats,
+              bool default_left, GradientT const &left_sum, GradientT const &right_sum) {
+    if (this->NeedReplace(new_loss_chg, split_index)) {
+      this->loss_chg = new_loss_chg;
+      if (default_left) {
+        split_index |= (1U << 31);
+      }
+      this->sindex = split_index;
+      cat_bits.resize(cats.Bits().size());
+      std::copy(cats.Bits().begin(), cats.Bits().end(), cat_bits.begin());
+      this->is_cat = true;
       this->left_sum = left_sum;
       this->right_sum = right_sum;
       return true;
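Note: both overloads pack the default (missing-value) direction into the high bit of `sindex` — `split_index |= (1U << 31)` — so the feature index and the default direction travel in one 32-bit field. A hedged standalone sketch of that packing:

    #include <cstdint>

    constexpr std::uint32_t kDefaultLeftBit = 1u << 31;

    std::uint32_t PackSindex(std::uint32_t feature_idx, bool default_left) {
      return default_left ? (feature_idx | kDefaultLeftBit) : feature_idx;
    }
    std::uint32_t SplitIndex(std::uint32_t sindex) { return sindex & ~kDefaultLeftBit; }
    bool DefaultLeft(std::uint32_t sindex) { return (sindex & kDefaultLeftBit) != 0; }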
@@ -92,7 +92,7 @@ class TreeEvaluator {
 
     XGBOOST_DEVICE float CalcWeight(bst_node_t nodeid, const ParamT &param,
                                     tree::GradStats const& stats) const {
-      float w = xgboost::tree::CalcWeight(param, stats);
+      float w = ::xgboost::tree::CalcWeight(param, stats);
       if (!has_constraint) {
         return w;
       }
@@ -107,6 +107,12 @@ class TreeEvaluator {
         return w;
       }
     }
+
+    template <typename GradientSumT>
+    XGBOOST_DEVICE double CalcWeightCat(ParamT const& param, GradientSumT const& stats) const {
+      return ::xgboost::tree::CalcWeight(param, stats);
+    }
+
     XGBOOST_DEVICE float
     CalcGainGivenWeight(ParamT const &p, tree::GradStats const& stats, float w) const {
       if (stats.GetHess() <= 0) {
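Note: `CalcWeightCat` deliberately routes through the unconstrained `CalcWeight`, since per-category weights serve only as a sort key for partition-based splits and monotone-constraint clamping would distort the ordering. As a hedged reminder, ignoring `max_delta_step` and related refinements, the weight it returns is the usual

    w = -\frac{\sum_i g_i}{\sum_i h_i + \lambda}

computed from the gradient and hessian sums of the histogram bin.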
@@ -336,10 +336,10 @@ class ColMaker: public TreeUpdater {
               bst_float proposed_split = (fvalue + e.last_fvalue) * 0.5f;
               if ( proposed_split == fvalue ) {
                 e.best.Update(loss_chg, fid, e.last_fvalue,
-                              d_step == -1, c, e.stats);
+                              d_step == -1, false, c, e.stats);
               } else {
                 e.best.Update(loss_chg, fid, proposed_split,
-                              d_step == -1, c, e.stats);
+                              d_step == -1, false, c, e.stats);
               }
             } else {
               loss_chg = static_cast<bst_float>(
@@ -348,10 +348,10 @@ class ColMaker: public TreeUpdater {
               bst_float proposed_split = (fvalue + e.last_fvalue) * 0.5f;
               if ( proposed_split == fvalue ) {
                 e.best.Update(loss_chg, fid, e.last_fvalue,
-                              d_step == -1, e.stats, c);
+                              d_step == -1, false, e.stats, c);
               } else {
                 e.best.Update(loss_chg, fid, proposed_split,
-                              d_step == -1, e.stats, c);
+                              d_step == -1, false, e.stats, c);
               }
             }
           }
@@ -430,14 +430,14 @@ class ColMaker: public TreeUpdater {
             loss_chg = static_cast<bst_float>(
                 evaluator.CalcSplitGain(param_, nid, fid, c, e.stats) -
                 snode_[nid].root_gain);
-            e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, c,
-                          e.stats);
+            e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1,
+                          false, c, e.stats);
           } else {
             loss_chg = static_cast<bst_float>(
                 evaluator.CalcSplitGain(param_, nid, fid, e.stats, c) -
                 snode_[nid].root_gain);
             e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1,
-                          e.stats, c);
+                          false, e.stats, c);
           }
         }
       }
@@ -173,7 +173,8 @@ class HistMaker: public BaseMaker {
         if (c.sum_hess >= param_.min_child_weight) {
           double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) +
                             CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain;
-          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i], false, s, c)) {
+          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i],
+                           false, false, s, c)) {
             *left_sum = s;
           }
         }
@@ -187,7 +188,8 @@ class HistMaker: public BaseMaker {
         if (c.sum_hess >= param_.min_child_weight) {
           double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) +
                             CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain;
-          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, c, s)) {
+          if (best->Update(static_cast<bst_float>(loss_chg), fid,
+                           hist.cut[i - 1], true, false, c, s)) {
             *left_sum = c;
           }
         }
@@ -168,9 +168,11 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
 
   std::vector<CPUExpandEntry> entries{node};
   builder_monitor_.Start("EvaluateSplits");
+  auto ft = p_fmat->Info().feature_types.ConstHostSpan();
   for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(
            BatchParam{GenericParameter::kCpuId, param_.max_bin})) {
-    evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, *p_tree, &entries);
+    evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft,
+                               *p_tree, &entries);
     break;
   }
   builder_monitor_.Stop("EvaluateSplits");
@@ -272,8 +274,9 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
   }
 
   builder_monitor_.Start("EvaluateSplits");
-  evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), gmat.cut,
-                             *p_tree, &nodes_to_evaluate);
+  auto ft = p_fmat->Info().feature_types.ConstHostSpan();
+  evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(),
+                             gmat.cut, ft, *p_tree, &nodes_to_evaluate);
   builder_monitor_.Stop("EvaluateSplits");
 
   for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) {
@@ -529,11 +532,11 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
   // store a pointer to the tree
   p_last_tree_ = &tree;
   if (data_layout_ == DataLayout::kDenseDataOneBased) {
-    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{param_, info, this->nthread_,
-                                                                     column_sampler_, true});
+    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
+        param_, info, this->nthread_, column_sampler_, task_, true});
   } else {
-    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{param_, info, this->nthread_,
-                                                                     column_sampler_, false});
+    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
+        param_, info, this->nthread_, column_sampler_, task_, false});
   }
 
   if (data_layout_ == DataLayout::kDenseDataZeroBased
tests/cpp/categorical_helpers.h (new file, 44 lines)
@@ -0,0 +1,44 @@
+/*!
+ * Copyright 2021 by XGBoost Contributors
+ *
+ * \brief Utilities for testing categorical data support.
+ */
+#include <numeric>
+#include <vector>
+
+#include "xgboost/span.h"
+#include "helpers.h"
+#include "../../src/common/categorical.h"
+
+namespace xgboost {
+inline std::vector<float> OneHotEncodeFeature(std::vector<float> x,
+                                              size_t num_cat) {
+  std::vector<float> ret(x.size() * num_cat, 0);
+  size_t n_rows = x.size();
+  for (size_t r = 0; r < n_rows; ++r) {
+    bst_cat_t cat = common::AsCat(x[r]);
+    ret.at(num_cat * r + cat) = 1;
+  }
+  return ret;
+}
+
+template <typename GradientSumT>
+void ValidateCategoricalHistogram(size_t n_categories,
+                                  common::Span<GradientSumT> onehot,
+                                  common::Span<GradientSumT> cat) {
+  auto cat_sum = std::accumulate(cat.cbegin(), cat.cend(), GradientPairPrecise{});
+  for (size_t c = 0; c < n_categories; ++c) {
+    auto zero = onehot[c * 2];
+    auto one = onehot[c * 2 + 1];
+
+    auto chosen = cat[c];
+    auto not_chosen = cat_sum - chosen;
+
+    ASSERT_LE(RelError(zero.GetGrad(), not_chosen.GetGrad()), kRtEps);
+    ASSERT_LE(RelError(zero.GetHess(), not_chosen.GetHess()), kRtEps);
+
+    ASSERT_LE(RelError(one.GetGrad(), chosen.GetGrad()), kRtEps);
+    ASSERT_LE(RelError(one.GetHess(), chosen.GetHess()), kRtEps);
+  }
+}
+} // namespace xgboost
@@ -5,6 +5,13 @@
 #include "../../../src/common/quantile.cuh"
 
 namespace xgboost {
+namespace {
+struct IsSorted {
+  XGBOOST_DEVICE bool operator()(common::SketchEntry const& a, common::SketchEntry const& b) const {
+    return a.value < b.value;
+  }
+};
+}
 namespace common {
 TEST(GPUQuantile, Basic) {
   constexpr size_t kRows = 1000, kCols = 100, kBins = 256;
@@ -52,9 +59,15 @@ void TestSketchUnique(float sparsity) {
     ASSERT_EQ(sketch.Data().size(), h_columns_ptr.back());
 
     sketch.Unique();
-    ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
-                                  sketch.Data().data() + sketch.Data().size(),
-                                  detail::SketchUnique{}));
+
+    std::vector<SketchEntry> h_data(sketch.Data().size());
+    thrust::copy(dh::tcbegin(sketch.Data()), dh::tcend(sketch.Data()), h_data.begin());
+
+    for (size_t i = 1; i < h_columns_ptr.size(); ++i) {
+      auto begin = h_columns_ptr[i - 1];
+      auto column = common::Span<SketchEntry>(h_data).subspan(begin, h_columns_ptr[i] - begin);
+      ASSERT_TRUE(std::is_sorted(column.begin(), column.end(), IsSorted{}));
+    }
   });
 }
 
@@ -84,8 +97,7 @@ void TestQuantileElemRank(int32_t device, Span<SketchEntry const> in,
       if (with_error) {
         ASSERT_GE(in_column[idx].rmin + in_column[idx].rmin * kRtEps,
                   prev_rmin);
-        ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps,
-                  prev_rmax);
+        ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps, prev_rmax);
         ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps,
                   rmin_next);
       } else {
@@ -169,7 +181,7 @@ TEST(GPUQuantile, MergeEmpty) {
 
 TEST(GPUQuantile, MergeBasic) {
   constexpr size_t kRows = 1000, kCols = 100;
-  RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
+  RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const &info) {
     HostDeviceVector<FeatureType> ft;
     SketchContainer sketch_0(ft, n_bins, kCols, kRows, 0);
     HostDeviceVector<float> storage_0;
@@ -265,9 +277,16 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
   ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
 
   sketch_0.Unique();
-  ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch_0.Data().data(),
-                                sketch_0.Data().data() + sketch_0.Data().size(),
-                                detail::SketchUnique{}));
+  columns_ptr = sketch_0.ColumnsPtr();
+  dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
+
+  std::vector<SketchEntry> h_data(sketch_0.Data().size());
+  dh::CopyDeviceSpanToVector(&h_data, sketch_0.Data());
+  for (size_t i = 1; i < h_columns_ptr.size(); ++i) {
+    auto begin = h_columns_ptr[i - 1];
+    auto column = Span<SketchEntry> {h_data}.subspan(begin, h_columns_ptr[i] - begin);
+    ASSERT_TRUE(std::is_sorted(column.begin(), column.end(), IsSorted{}));
+  }
 }
 
 TEST(GPUQuantile, MergeDuplicated) {
@@ -48,7 +48,9 @@ template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
   std::vector<MetaInfo> infos(2);
   auto& h_weights = infos.front().weights_.HostVector();
   h_weights.resize(rows);
-  std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); });
+
+  SimpleRealUniformDistribution<float> weight_dist(0, 10);
+  std::generate(h_weights.begin(), h_weights.end(), [&]() { return weight_dist(&lcg); });
 
   for (auto seed : seeds) {
     for (auto n_bin : bins) {
@@ -172,12 +172,10 @@ SimpleLCG::StateType SimpleLCG::operator()() {
   state_ = (alpha_ * state_) % mod_;
   return state_;
 }
-SimpleLCG::StateType SimpleLCG::Min() const {
-  return seed_ * alpha_;
-}
-SimpleLCG::StateType SimpleLCG::Max() const {
-  return max_value_;
-}
+SimpleLCG::StateType SimpleLCG::Min() const { return min(); }
+SimpleLCG::StateType SimpleLCG::Max() const { return max(); }
+// Make sure it's compile time constant.
+static_assert(SimpleLCG::max() - SimpleLCG::min(), "");
 
 void RandomDataGenerator::GenerateDense(HostDeviceVector<float> *out) const {
   xgboost::SimpleRealUniformDistribution<bst_float> dist(lower_, upper_);
@@ -291,6 +289,7 @@ void RandomDataGenerator::GenerateCSR(
 
   xgboost::SimpleRealUniformDistribution<bst_float> dist(lower_, upper_);
   float sparsity = sparsity_ * (upper_ - lower_) + lower_;
+  SimpleRealUniformDistribution<bst_float> cat(0.0, max_cat_);
 
   h_rptr.emplace_back(0);
   for (size_t i = 0; i < rows_; ++i) {
@@ -298,7 +297,11 @@ void RandomDataGenerator::GenerateCSR(
     for (size_t j = 0; j < cols_; ++j) {
       auto g = dist(&lcg);
       if (g >= sparsity) {
-        g = dist(&lcg);
+        if (common::IsCat(ft_, j)) {
+          g = common::AsCat(cat(&lcg));
+        } else {
+          g = dist(&lcg);
+        }
         h_value.emplace_back(g);
         rptr++;
         h_cols.emplace_back(j);
@@ -347,11 +350,15 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
   }
   if (device_ >= 0) {
     out->Info().labels_.SetDevice(device_);
+    out->Info().feature_types.SetDevice(device_);
     for (auto const& page : out->GetBatches<SparsePage>()) {
      page.data.SetDevice(device_);
      page.offset.SetDevice(device_);
    }
   }
+  if (!ft_.empty()) {
+    out->Info().feature_types.HostVector() = ft_;
+  }
   return out;
 }
@@ -106,42 +106,39 @@ bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
  */
 class SimpleLCG {
  private:
-  using StateType = int64_t;
+  using StateType = uint64_t;
   static StateType constexpr kDefaultInit = 3;
-  static StateType constexpr default_alpha_ = 61;
-  static StateType constexpr max_value_ = ((StateType)1 << 32) - 1;
+  static StateType constexpr kDefaultAlpha = 61;
+  static StateType constexpr kMaxValue = (static_cast<StateType>(1) << 32) - 1;
 
   StateType state_;
   StateType const alpha_;
   StateType const mod_;
 
-  StateType seed_;
+ public:
+  using result_type = StateType;  // NOLINT
 
  public:
-  SimpleLCG() : state_{kDefaultInit},
-                alpha_{default_alpha_}, mod_{max_value_}, seed_{state_}{}
+  SimpleLCG() : state_{kDefaultInit}, alpha_{kDefaultAlpha}, mod_{kMaxValue} {}
   SimpleLCG(SimpleLCG const& that) = default;
   SimpleLCG(SimpleLCG&& that) = default;
 
-  void Seed(StateType seed) {
-    seed_ = seed;
-  }
+  void Seed(StateType seed) { state_ = seed % mod_; }
   /*!
    * \brief Initialize SimpleLCG.
    *
    * \param state Initial state, can also be considered as seed. If set to
    *              zero, SimpleLCG will use internal default value.
-   * \param alpha multiplier
-   * \param mod   modulo
    */
-  explicit SimpleLCG(StateType state,
-                     StateType alpha=default_alpha_, StateType mod=max_value_)
-      : state_{state == 0 ? kDefaultInit : state},
-        alpha_{alpha}, mod_{mod} , seed_{state} {}
+  explicit SimpleLCG(StateType state)
+      : state_{state == 0 ? kDefaultInit : state}, alpha_{kDefaultAlpha}, mod_{kMaxValue} {}
 
   StateType operator()();
   StateType Min() const;
   StateType Max() const;
+
+  constexpr result_type static min() { return 0; };  // NOLINT
+  constexpr result_type static max() { return kMaxValue; }  // NOLINT
 };
 
 template <typename ResultT>
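Note: the `result_type`, `min()`, and `max()` additions are exactly what the standard UniformRandomBitGenerator requirements ask for, which is what lets the new CPU test hand a `SimpleLCG` straight to `std::shuffle`. A hedged minimal illustration with only the standard library:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // The minimum interface std::shuffle needs from a generator.
    struct TinyLcg {
      using result_type = std::uint64_t;                               // NOLINT
      result_type state{3};
      static constexpr result_type min() { return 0; }                 // NOLINT
      static constexpr result_type max() { return (1ull << 32) - 1; }  // NOLINT
      result_type operator()() { return state = (61 * state) % max(); }
    };

    void ShuffleDemo() {
      std::vector<int> v{1, 2, 3, 4, 5};
      TinyLcg lcg;
      std::shuffle(v.begin(), v.end(), lcg);  // accepted as a URBG
    }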
@@ -217,10 +214,12 @@ class RandomDataGenerator {
   float upper_;
 
   int32_t device_;
-  int32_t seed_;
+  uint64_t seed_;
   SimpleLCG lcg_;
 
   size_t bins_;
+  std::vector<FeatureType> ft_;
+  bst_cat_t max_cat_;
 
   Json ArrayInterfaceImpl(HostDeviceVector<float> *storage, size_t rows,
                           size_t cols) const;
@@ -242,7 +241,7 @@ class RandomDataGenerator {
     device_ = d;
     return *this;
   }
-  RandomDataGenerator& Seed(int32_t s) {
+  RandomDataGenerator& Seed(uint64_t s) {
     seed_ = s;
     lcg_.Seed(seed_);
     return *this;
@@ -251,6 +250,16 @@ class RandomDataGenerator {
     bins_ = b;
     return *this;
   }
+  RandomDataGenerator& Type(common::Span<FeatureType> ft) {
+    CHECK_EQ(ft.size(), cols_);
+    ft_.resize(ft.size());
+    std::copy(ft.cbegin(), ft.cend(), ft_.begin());
+    return *this;
+  }
+  RandomDataGenerator& MaxCategory(bst_cat_t cat) {
+    max_cat_ = cat;
+    return *this;
+  }
 
   void GenerateDense(HostDeviceVector<float>* out) const;
 
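Note: with `Type` and `MaxCategory` in the builder chain, tests can request categorical columns directly. A hedged usage sketch mirroring the new tests further down:

    // One categorical column with categories drawn from [0, 8).
    std::vector<xgboost::FeatureType> ft(1, xgboost::FeatureType::kCategorical);
    auto dmat = xgboost::RandomDataGenerator{128, 1, 0.0f}
                    .Seed(3)
                    .Type(ft)
                    .MaxCategory(8)
                    .GenerateDMatrix();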
@@ -1,9 +1,11 @@
 #include <gtest/gtest.h>
 #include <vector>
-#include "../../helpers.h"
+
 #include "../../../../src/common/categorical.h"
-#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
 #include "../../../../src/tree/gpu_hist/histogram.cuh"
+#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
+#include "../../categorical_helpers.h"
+#include "../../helpers.h"
 
 namespace xgboost {
 namespace tree {
@@ -99,16 +101,6 @@ TEST(Histogram, GPUDeterministic) {
   }
 }
 
-std::vector<float> OneHotEncodeFeature(std::vector<float> x, size_t num_cat) {
-  std::vector<float> ret(x.size() * num_cat, 0);
-  size_t n_rows = x.size();
-  for (size_t r = 0; r < n_rows; ++r) {
-    bst_cat_t cat = common::AsCat(x[r]);
-    ret.at(num_cat * r + cat) = 1;
-  }
-  return ret;
-}
-
 // Test 1 vs rest categorical histogram is equivalent to one hot encoded data.
 void TestGPUHistogramCategorical(size_t num_categories) {
   size_t constexpr kRows = 340;
@@ -123,7 +115,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   auto gpair = GenerateRandomGradients(kRows, 0, 2);
   gpair.SetDevice(0);
   auto rounding = CreateRoundingFactor<GradientPairPrecise>(gpair.DeviceSpan());
-  // Generate hist with cat data.
+  /**
+   * Generate hist with cat data.
+   */
   for (auto const &batch : cat_m->GetBatches<EllpackPage>(batch_param)) {
     auto* page = batch.Impl();
     FeatureGroups single_group(page->Cuts());
@@ -133,7 +127,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
                             rounding);
   }
 
-  // Generate hist with one hot encoded data.
+  /**
+   * Generate hist with one hot encoded data.
+   */
   auto x_encoded = OneHotEncodeFeature(x, num_categories);
   auto encode_m = GetDMatrixFromData(x_encoded, kRows, num_categories);
   dh::device_vector<GradientPairPrecise> encode_hist(2 * num_categories);
@@ -152,20 +148,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
 
   std::vector<GradientPairPrecise> h_encode_hist(encode_hist.size());
   thrust::copy(encode_hist.begin(), encode_hist.end(), h_encode_hist.begin());
-
-  for (size_t c = 0; c < num_categories; ++c) {
-    auto zero = h_encode_hist[c * 2];
-    auto one = h_encode_hist[c * 2 + 1];
-
-    auto chosen = h_cat_hist[c];
-    auto not_chosen = cat_sum - chosen;
-
-    ASSERT_LE(RelError(zero.GetGrad(), not_chosen.GetGrad()), kRtEps);
-    ASSERT_LE(RelError(zero.GetHess(), not_chosen.GetHess()), kRtEps);
-
-    ASSERT_LE(RelError(one.GetGrad(), chosen.GetGrad()), kRtEps);
-    ASSERT_LE(RelError(one.GetHess(), chosen.GetHess()), kRtEps);
-  }
+  ValidateCategoricalHistogram(num_categories,
+                               common::Span<GradientPairPrecise>{h_encode_hist},
+                               common::Span<GradientPairPrecise>{h_cat_hist});
 }
 
 TEST(Histogram, GPUHistCategorical) {
@@ -7,7 +7,6 @@
 
 namespace xgboost {
 namespace tree {
-
 template <typename GradientSumT> void TestEvaluateSplits() {
   int static constexpr kRows = 8, kCols = 16;
   auto orig = omp_get_max_threads();
@@ -16,14 +15,12 @@ template <typename GradientSumT> void TestEvaluateSplits() {
   auto sampler = std::make_shared<common::ColumnSampler>();
 
   TrainParam param;
-  param.UpdateAllowUnknown(Args{{}});
-  param.min_child_weight = 0;
-  param.reg_lambda = 0;
+  param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
 
   auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
 
-  auto evaluator =
-      HistEvaluator<GradientSumT, CPUExpandEntry>{param, dmat->Info(), n_threads, sampler};
+  auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{
+      param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}};
   common::HistCollection<GradientSumT> hist;
   std::vector<GradientPair> row_gpairs = {
       {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
@@ -39,7 +36,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
   std::iota(row_indices.begin(), row_indices.end(), 0);
   row_set_collection.Init();
 
-  auto hist_builder = GHistBuilder<GradientSumT>(n_threads, gmat.cut.Ptrs().back());
+  auto hist_builder = GHistBuilder<GradientSumT>(omp_get_max_threads(), gmat.cut.Ptrs().back());
   hist.Init(gmat.cut.Ptrs().back());
   hist.AddHistRow(0);
   hist.AllocateAllData();
@@ -58,7 +55,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
   entries.front().depth = 0;
 
   evaluator.InitRoot(GradStats{total_gpair});
-  evaluator.EvaluateSplits(hist, gmat.cut, tree, &entries);
+  evaluator.EvaluateSplits(hist, gmat.cut, {}, tree, &entries);
 
   auto best_loss_chg =
       evaluator.Evaluator().CalcSplitGain(
@@ -96,8 +93,8 @@ TEST(HistEvaluator, Apply) {
   param.UpdateAllowUnknown(Args{{}});
   auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
   auto sampler = std::make_shared<common::ColumnSampler>();
-  auto evaluator_ =
-      HistEvaluator<float, CPUExpandEntry>{param, dmat->Info(), 4, sampler};
+  auto evaluator_ = HistEvaluator<float, CPUExpandEntry>{param, dmat->Info(), 4, sampler,
+                                                         ObjInfo{ObjInfo::kRegression}};
 
   CPUExpandEntry entry{0, 0, 10.0f};
   entry.split.left_sum = GradStats{0.4, 0.6f};
@@ -108,5 +105,142 @@ TEST(HistEvaluator, Apply) {
   ASSERT_EQ(tree.Stat(tree[0].LeftChild()).sum_hess, 0.6f);
   ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.7f);
 }
+
+TEST(HistEvaluator, CategoricalPartition) {
+  int static constexpr kRows = 128, kCols = 1;
+  using GradientSumT = double;
+  std::vector<FeatureType> ft(kCols, FeatureType::kCategorical);
+
+  TrainParam param;
+  param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
+
+  size_t n_cats{8};
+
+  auto dmat =
+      RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
+
+  int32_t n_threads = 16;
+  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{
+      param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}};
+
+  for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 32})) {
+    common::HistCollection<GradientSumT> hist;
+
+    std::vector<CPUExpandEntry> entries(1);
+    entries.front().nid = 0;
+    entries.front().depth = 0;
+
+    hist.Init(gmat.cut.TotalBins());
+    hist.AddHistRow(0);
+    hist.AllocateAllData();
+    auto node_hist = hist[0];
+    ASSERT_EQ(node_hist.size(), n_cats);
+    ASSERT_EQ(node_hist.size(), gmat.cut.Ptrs().back());
+
+    GradientPairPrecise total_gpair;
+    for (size_t i = 0; i < node_hist.size(); ++i) {
+      node_hist[i] = {static_cast<double>(node_hist.size() - i), 1.0};
+      total_gpair += node_hist[i];
+    }
+    SimpleLCG lcg;
+    std::shuffle(node_hist.begin(), node_hist.end(), lcg);
+
+    RegTree tree;
+    evaluator.InitRoot(GradStats{total_gpair});
+    evaluator.EvaluateSplits(hist, gmat.cut, ft, tree, &entries);
+    ASSERT_TRUE(entries.front().split.is_cat);
+
+    auto run_eval = [&](auto fn) {
+      for (size_t i = 1; i < gmat.cut.Ptrs().size(); ++i) {
+        GradStats left, right;
+        for (size_t j = gmat.cut.Ptrs()[i - 1]; j < gmat.cut.Ptrs()[i]; ++j) {
+          auto loss_chg = evaluator.Evaluator().CalcSplitGain(param, 0, i - 1, left, right) -
+                          evaluator.Stats().front().root_gain;
+          fn(loss_chg);
+          left.Add(node_hist[j].GetGrad(), node_hist[j].GetHess());
+          right.SetSubstract(GradStats{total_gpair}, left);
+        }
+      }
+    };
+    // Assert that's the best split
+    auto best_loss_chg = entries.front().split.loss_chg;
+    run_eval([&](auto loss_chg) {
+      // Approximated test that gain returned by optimal partition is greater than
+      // numerical split.
+      ASSERT_GT(best_loss_chg, loss_chg);
+    });
+    // node_hist is captured in lambda.
+    std::sort(node_hist.begin(), node_hist.end(), [&](auto l, auto r) {
+      return evaluator.Evaluator().CalcWeightCat(param, l) <
+             evaluator.Evaluator().CalcWeightCat(param, r);
+    });
+
+    double reimpl = 0;
+    run_eval([&](auto loss_chg) { reimpl = std::max(loss_chg, reimpl); });
+    CHECK_EQ(reimpl, best_loss_chg);
+  }
+}
+
+namespace {
+auto CompareOneHotAndPartition(bool onehot) {
+  int static constexpr kRows = 128, kCols = 1;
+  using GradientSumT = double;
+  std::vector<FeatureType> ft(kCols, FeatureType::kCategorical);
+
+  TrainParam param;
+  if (onehot) {
+    // force use one-hot
+    param.UpdateAllowUnknown(
+        Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}, {"max_cat_to_onehot", "100"}});
+  } else {
+    param.UpdateAllowUnknown(
+        Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}, {"max_cat_to_onehot", "1"}});
+  }
+
+  size_t n_cats{2};
+
+  auto dmat =
+      RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
+
+  int32_t n_threads = 16;
+  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{
+      param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}};
+  std::vector<CPUExpandEntry> entries(1);
+
+  for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 32})) {
+    common::HistCollection<GradientSumT> hist;
+
+    entries.front().nid = 0;
+    entries.front().depth = 0;
+
+    hist.Init(gmat.cut.TotalBins());
+    hist.AddHistRow(0);
+    hist.AllocateAllData();
+    auto node_hist = hist[0];
+
+    CHECK_EQ(node_hist.size(), n_cats);
+    CHECK_EQ(node_hist.size(), gmat.cut.Ptrs().back());
+
+    GradientPairPrecise total_gpair;
+    for (size_t i = 0; i < node_hist.size(); ++i) {
+      node_hist[i] = {static_cast<double>(node_hist.size() - i), 1.0};
+      total_gpair += node_hist[i];
+    }
+    RegTree tree;
+    evaluator.InitRoot(GradStats{total_gpair});
+    evaluator.EvaluateSplits(hist, gmat.cut, ft, tree, &entries);
+  }
+  return entries.front();
+}
+}  // anonymous namespace
+
+TEST(HistEvaluator, Categorical) {
+  auto with_onehot = CompareOneHotAndPartition(true);
+  auto with_part = CompareOneHotAndPartition(false);
+
+  ASSERT_EQ(with_onehot.split.loss_chg, with_part.split.loss_chg);
+}
 } // namespace tree
 } // namespace xgboost
@@ -88,14 +88,14 @@ TEST(Param, SplitEntry) {
 
   xgboost::tree::SplitEntry se2;
   EXPECT_FALSE(se1.Update(se2));
-  EXPECT_FALSE(se2.Update(-1, 100, 0, true, xgboost::tree::GradStats(),
+  EXPECT_FALSE(se2.Update(-1, 100, 0, true, false, xgboost::tree::GradStats(),
                           xgboost::tree::GradStats()));
-  ASSERT_TRUE(se2.Update(1, 100, 0, true, xgboost::tree::GradStats(),
+  ASSERT_TRUE(se2.Update(1, 100, 0, true, false, xgboost::tree::GradStats(),
                          xgboost::tree::GradStats()));
   ASSERT_TRUE(se1.Update(se2));
 
   xgboost::tree::SplitEntry se3;
-  se3.Update(2, 101, 0, false, xgboost::tree::GradStats(),
+  se3.Update(2, 101, 0, false, false, xgboost::tree::GradStats(),
              xgboost::tree::GradStats());
   xgboost::tree::SplitEntry::Reduce(se2, se3);
   EXPECT_EQ(se2.SplitIndex(), 101);