Handle missing categorical value in CPU evaluator. (#7948)

parent 2070afea02
commit bde4f25794
@@ -119,13 +119,90 @@ class HistEvaluator {
     p_best->Update(best);
   }
 
+  /**
+   * \brief Enumerate with partition-based splits.
+   *
+   * The implementation is different from LightGBM. Firstly we don't have a
+   * pseudo-category for missing values; instead we make 2 complete scans over the
+   * histogram. Secondly, both scan directions generate splits in the same
+   * order. The following table depicts the scan process; square brackets mean the
+   * gradient of missing values resides in that partition:
+   *
+   * | Forward  | Backward |
+   * |----------+----------|
+   * | [BCDE] A | E [ABCD] |
+   * | [CDE] AB | DE [ABC] |
+   * | [DE] ABC | CDE [AB] |
+   * | [E] ABCD | BCDE [A] |
+   */
+  template <int d_step>
+  void EnumeratePart(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
+                     common::GHistRow const &hist, bst_feature_t fidx, bst_node_t nidx,
+                     TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
+                     SplitEntry *p_best) {
+    static_assert(d_step == +1 || d_step == -1, "Invalid step.");
+
+    auto const &cut_ptr = cut.Ptrs();
+    auto const &parent = snode_[nidx];
+    bst_bin_t n_bins{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
+
+    // statistics on both sides of split
+    GradStats left_sum;
+    GradStats right_sum;
+    // best split so far
+    SplitEntry best;
+
+    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
+    bst_bin_t ibegin, iend;
+    bst_bin_t f_begin = cut_ptr[fidx];
+    if (d_step > 0) {
+      ibegin = f_begin;
+      iend = ibegin + n_bins - 1;
+    } else {
+      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
+      iend = f_begin;
+    }
+
+    bst_bin_t best_thresh{-1};
+    for (bst_bin_t i = ibegin; i != iend; i += d_step) {
+      auto j = i - f_begin;  // index local to current feature
+      if (d_step == 1) {
+        right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
+        left_sum.SetSubstract(parent.stats, right_sum);  // missing on left
+      } else {
+        left_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
+        right_sum.SetSubstract(parent.stats, left_sum);  // missing on right
+      }
+      if (IsValid(left_sum, right_sum)) {
+        auto loss_chg =
+            evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
+            parent.root_gain;
+        // We don't have a numeric split point; NaN here is a dummy split.
+        if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
+                        left_sum, right_sum)) {
+          best_thresh = i;
+        }
+      }
+    }
+
+    if (best_thresh != -1) {
+      auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
+      best.cat_bits = decltype(best.cat_bits)(n, 0);
+      common::CatBitField cat_bits{best.cat_bits};
+      bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : best_thresh - iend;
+      std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
+                    [&](size_t c) { cat_bits.Set(c); });
+    }
+
+    p_best->Update(best);
+  }
+
   // Enumerate/Scan the split values of a specific feature.
   // Returns the sum of gradients corresponding to the data points that contain
   // a non-missing value for the particular feature fid.
-  template <int d_step, SplitType split_type>
+  template <int d_step>
   GradStats EnumerateSplit(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
-                           const common::GHistRow &hist, bst_feature_t fidx,
-                           bst_node_t nidx,
+                           const common::GHistRow &hist, bst_feature_t fidx, bst_node_t nidx,
                            TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                            SplitEntry *p_best) const {
     static_assert(d_step == +1 || d_step == -1, "Invalid step.");
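The scan described in the new doc comment can be reproduced with a short standalone sketch. This is illustrative Python only; the category names, gradient values, and helper below are made up and are not XGBoost code:

```python
# Sketch of the two-scan partition enumeration described in the comment above.
# Category names and gradient values are invented; this is not XGBoost code.
parent_sum = 10.0  # node total, includes rows with missing values
hist = {"A": 1.0, "B": 2.0, "C": 1.5, "D": 3.0, "E": 0.5}  # per-category grad
cats = sorted(hist)  # stand-in for the sorted_idx ordering


def enumerate_part(d_step: int) -> None:
    """d_step == +1: missing values stay left; d_step == -1: they stay right."""
    order = cats if d_step == 1 else list(reversed(cats))
    scanned_sum = 0.0
    for i in range(len(order) - 1):  # one side must remain non-empty
        scanned_sum += hist[order[i]]
        rest_sum = parent_sum - scanned_sum  # complement, with missing values
        scanned = "".join(sorted(order[: i + 1]))
        rest = "".join(c for c in cats if c not in scanned)
        if d_step == 1:  # scanned categories go right, missing stays left
            print(f"[{rest}] {scanned}  left={rest_sum} right={scanned_sum}")
        else:  # scanned categories go left, missing stays right
            print(f"{scanned} [{rest}]  left={scanned_sum} right={rest_sum}")


enumerate_part(+1)  # forward:  [BCDE] A, [CDE] AB, [DE] ABC, [E] ABCD
enumerate_part(-1)  # backward: E [ABCD], DE [ABC], CDE [AB], BCDE [A]
```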
@@ -134,8 +211,6 @@ class HistEvaluator {
     const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
     const std::vector<bst_float> &cut_val = cut.Values();
     auto const &parent = snode_[nidx];
-    int32_t n_bins{static_cast<int32_t>(cut_ptr.at(fidx + 1) - cut_ptr[fidx])};
-    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
 
     // statistics on both sides of split
     GradStats left_sum;
@@ -144,50 +219,28 @@ class HistEvaluator {
     SplitEntry best;
 
     // bin boundaries
-    CHECK_LE(cut_ptr[fidx], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-    CHECK_LE(cut_ptr[fidx + 1], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+    CHECK_LE(cut_ptr[fidx], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
+    CHECK_LE(cut_ptr[fidx + 1], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
     // imin: index (offset) of the minimum value for feature fid; we need this for
     // backward enumeration
-    const auto imin = static_cast<int32_t>(cut_ptr[fidx]);
+    const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);
     // ibegin, iend: smallest/largest cut points for feature fid; use a signed type to
     // allow for value -1
-    int32_t ibegin, iend;
+    bst_bin_t ibegin, iend;
     if (d_step > 0) {
-      ibegin = static_cast<int32_t>(cut_ptr[fidx]);
-      iend = static_cast<int32_t>(cut_ptr.at(fidx + 1));
+      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
+      iend = static_cast<bst_bin_t>(cut_ptr.at(fidx + 1));
     } else {
-      ibegin = static_cast<int32_t>(cut_ptr[fidx + 1]) - 1;
-      iend = static_cast<int32_t>(cut_ptr[fidx]) - 1;
+      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
+      iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
     }
 
-    auto calc_bin_value = [&](auto i) {
-      switch (split_type) {
-        case kNum: {
-          left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
-          right_sum.SetSubstract(parent.stats, left_sum);
-          break;
-        }
-        case kOneHot: {
-          std::terminate();  // unreachable
-          break;
-        }
-        case kPart: {
-          auto j = d_step == 1 ? (i - ibegin) : (ibegin - i);
-          right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
-          left_sum.SetSubstract(parent.stats, right_sum);
-          break;
-        }
-      }
-    };
-
-    int32_t best_thresh{-1};
-    for (int32_t i = ibegin; i != iend; i += d_step) {
+    for (bst_bin_t i = ibegin; i != iend; i += d_step) {
       // start working
       // try to find a split
-      calc_bin_value(i);
-      bool improved{false};
-      if (left_sum.GetHess() >= param_.min_child_weight &&
-          right_sum.GetHess() >= param_.min_child_weight) {
+      left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
+      right_sum.SetSubstract(parent.stats, left_sum);
+      if (IsValid(left_sum, right_sum)) {
         bst_float loss_chg;
         bst_float split_pt;
         if (d_step > 0) {
@@ -197,66 +250,24 @@ class HistEvaluator {
                                                      GradStats{right_sum}) -
                                  parent.root_gain);
           split_pt = cut_val[i];  // not used for partition based
-          improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
-                                 left_sum, right_sum);
+          best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
         } else {
           // backward enumeration: split at left bound of each bin
           loss_chg =
               static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{right_sum},
                                                          GradStats{left_sum}) -
                                  parent.root_gain);
-          switch (split_type) {
-            case kNum: {
-              if (i == imin) {
-                split_pt = cut.MinValues()[fidx];
-              } else {
-                split_pt = cut_val[i - 1];
-              }
-              break;
-            }
-            case kOneHot: {
-              std::terminate();  // unreachable
-              break;
-            }
-            case kPart: {
-              split_pt = cut_val[i];
-              break;
-            }
+          if (i == imin) {
+            split_pt = cut.MinValues()[fidx];
+          } else {
+            split_pt = cut_val[i - 1];
           }
-          improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
-                                 right_sum, left_sum);
-        }
-        if (improved) {
-          best_thresh = i;
+          best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
         }
       }
     }
 
-    if (split_type == kPart && best_thresh != -1) {
-      auto n = common::CatBitField::ComputeStorageSize(n_bins);
-      best.cat_bits.resize(n, 0);
-      common::CatBitField cat_bits{best.cat_bits};
-
-      if (d_step == 1) {
-        std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1),
-                      [&](size_t c) { cat_bits.Set(cut_val[c + ibegin]); });
-      } else {
-        std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh),
-                      [&](size_t c) { cat_bits.Set(cut_val[c + cut_ptr[fidx]]); });
-      }
-    }
-
     p_best->Update(best);
-
-    switch (split_type) {
-      case kNum:
-        // Normal, accumulated to left
-        return left_sum;
-      case kOneHot:
-        return {};
-      case kPart:
-        // Accumulated to right due to chosen cats go to right.
-        return right_sum;
-    }
     return left_sum;
   }
 
@@ -316,14 +327,13 @@ class HistEvaluator {
               evaluator.CalcWeightCat(param_, feat_hist[r]);
           return ret;
         });
-        EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
-        EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
+        EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
+        EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
       }
     } else {
-      auto grad_stats =
-          EnumerateSplit<+1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
+      auto grad_stats = EnumerateSplit<+1>(cut, {}, histogram, fidx, nidx, evaluator, best);
       if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
-        EnumerateSplit<-1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
+        EnumerateSplit<-1>(cut, {}, histogram, fidx, nidx, evaluator, best);
       }
     }
   }
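A note on the numeric branch above: `EnumerateSplit` returns the gradient statistics accumulated over all non-missing rows, and the backward scan runs only when those statistics differ from the node totals, i.e. when the feature actually has missing values. A minimal Python sketch of that idea (the real check is `SplitContainsMissingValues`; all names and values below are made up):

```python
# Illustrative only: if the sums accumulated over the histogram equal the
# node totals, no row had a missing value for this feature, and the
# backward enumeration can be skipped. Names here are assumptions.
from typing import NamedTuple


class Stats(NamedTuple):
    grad: float
    hess: float


def split_contains_missing(enumerated: Stats, node_total: Stats) -> bool:
    return (enumerated.grad != node_total.grad
            or enumerated.hess != node_total.hess)


node = Stats(grad=10.0, hess=8.0)  # includes rows with missing values
seen = Stats(grad=9.5, hess=7.5)   # accumulated from non-missing rows only
assert split_contains_missing(seen, node)  # second (backward) scan needed
```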
@@ -50,12 +50,11 @@ struct CPUExpandEntry {
   }
 
   friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) {
-    os << "ExpandEntry: \n";
+    os << "ExpandEntry:\n";
     os << "nidx: " << e.nid << "\n";
     os << "depth: " << e.depth << "\n";
     os << "loss: " << e.split.loss_chg << "\n";
-    os << "left_sum: " << e.split.left_sum << "\n";
-    os << "right_sum: " << e.split.right_sum << "\n";
+    os << "split:\n" << e.split << std::endl;
     return os;
   }
 };
@@ -367,12 +367,14 @@ struct SplitEntryContainer {
 
   SplitEntryContainer() = default;
 
-  friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) {
-    os << "loss_chg: " << s.loss_chg << ", "
-       << "split index: " << s.SplitIndex() << ", "
-       << "split value: " << s.split_value << ", "
-       << "left_sum: " << s.left_sum << ", "
-       << "right_sum: " << s.right_sum;
+  friend std::ostream &operator<<(std::ostream &os, SplitEntryContainer const &s) {
+    os << "loss_chg: " << s.loss_chg << "\n"
+       << "dft_left: " << s.DefaultLeft() << "\n"
+       << "split_index: " << s.SplitIndex() << "\n"
+       << "split_value: " << s.split_value << "\n"
+       << "is_cat: " << s.is_cat << "\n"
+       << "left_sum: " << s.left_sum << "\n"
+       << "right_sum: " << s.right_sum << std::endl;
     return os;
   }
   /*!\return feature index to split on */
@@ -446,30 +448,6 @@ struct SplitEntryContainer {
     }
   }
 
-  /*!
-   * \brief Update with partition based categorical split.
-   *
-   * \return Whether the proposed split is better and can replace current split.
-   */
-  bool Update(float new_loss_chg, bst_feature_t split_index, common::KCatBitField cats,
-              bool default_left, GradientT const &left_sum, GradientT const &right_sum) {
-    if (this->NeedReplace(new_loss_chg, split_index)) {
-      this->loss_chg = new_loss_chg;
-      if (default_left) {
-        split_index |= (1U << 31);
-      }
-      this->sindex = split_index;
-      cat_bits.resize(cats.Bits().size());
-      std::copy(cats.Bits().begin(), cats.Bits().end(), cat_bits.begin());
-      this->is_cat = true;
-      this->left_sum = left_sum;
-      this->right_sum = right_sum;
-      return true;
-    } else {
-      return false;
-    }
-  }
-
   /*! \brief same as update, used by AllReduce*/
   inline static void Reduce(SplitEntryContainer &dst,          // NOLINT(*)
                             const SplitEntryContainer &src) {  // NOLINT(*)
@@ -147,9 +147,9 @@ auto CompareOneHotAndPartition(bool onehot) {
   auto dmat =
       RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
 
-  int32_t n_threads = 16;
   auto sampler = std::make_shared<common::ColumnSampler>();
-  auto evaluator = HistEvaluator<CPUExpandEntry>{param, dmat->Info(), n_threads, sampler};
+  auto evaluator =
+      HistEvaluator<CPUExpandEntry>{param, dmat->Info(), common::OmpGetNumThreads(0), sampler};
   std::vector<CPUExpandEntry> entries(1);
 
   for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
@@ -2,11 +2,14 @@
  * Copyright 2022 by XGBoost Contributors
  */
 #include <gtest/gtest.h>
+#include <xgboost/data.h>
 
 #include <algorithm>  // next_permutation
 #include <numeric>    // iota
 
-#include "../../../src/tree/hist/evaluate_splits.h"
+#include "../../../src/common/hist_util.h"  // HistogramCuts,HistCollection
+#include "../../../src/tree/param.h"        // TrainParam
+#include "../../../src/tree/split_evaluator.h"
 #include "../helpers.h"
 
 namespace xgboost {
@@ -77,6 +77,16 @@ class TestGPUUpdaters:
     def test_categorical(self, rows, cols, rounds, cats):
         self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
 
+    @given(
+        strategies.integers(10, 400),
+        strategies.integers(3, 8),
+        strategies.integers(4, 7)
+    )
+    @settings(deadline=None, print_blob=True)
+    @pytest.mark.skipif(**tm.no_pandas())
+    def test_categorical_missing(self, rows, cols, cats):
+        self.cputest.run_categorical_missing(rows, cols, cats, "gpu_hist")
+
     def test_max_cat(self) -> None:
         self.cputest.run_max_cat("gpu_hist")
@@ -1,5 +1,6 @@
 from random import choice
 from string import ascii_lowercase
+from typing import Dict, Any
 import testing as tm
 import pytest
 import xgboost as xgb
@@ -38,6 +39,9 @@ def train_result(param, dmat, num_rounds):
 
 
 class TestTreeMethod:
+    USE_ONEHOT = np.iinfo(np.int32).max
+    USE_PART = 1
+
     @given(exact_parameter_strategy, strategies.integers(1, 20),
            tm.dataset_strategy)
     @settings(deadline=None, print_blob=True)
@@ -213,10 +217,43 @@ class TestTreeMethod:
     def test_max_cat(self, tree_method) -> None:
         self.run_max_cat(tree_method)
 
-    def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
-        USE_ONEHOT = np.iinfo(np.int32).max
-        USE_PART = 1
+    def run_categorical_missing(
+        self, rows: int, cols: int, cats: int, tree_method: str
+    ) -> None:
+        parameters: Dict[str, Any] = {"tree_method": tree_method}
+        cat, label = tm.make_categorical(
+            n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
+        )
+        Xy = xgb.DMatrix(cat, label, enable_categorical=True)
+
+        def run(max_cat_to_onehot: int):
+            # Test with onehot splits
+            parameters["max_cat_to_onehot"] = max_cat_to_onehot
+
+            evals_result: Dict[str, Dict] = {}
+            booster = xgb.train(
+                parameters,
+                Xy,
+                num_boost_round=16,
+                evals=[(Xy, "Train")],
+                evals_result=evals_result
+            )
+            assert tm.non_increasing(evals_result["Train"]["rmse"])
+            y_predt = booster.predict(Xy)
+
+            rmse = tm.root_mean_square(label, y_predt)
+            np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])
+
+        # Test with OHE split
+        run(self.USE_ONEHOT)
+
+        if tree_method == "gpu_hist":  # fixme: Test with GPU.
+            return
+
+        # Test with partition-based split
+        run(self.USE_PART)
+
+    def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
         onehot, label = tm.make_categorical(rows, cols, cats, True)
         cat, _ = tm.make_categorical(rows, cols, cats, False)
 
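The behaviour exercised by the new `run_categorical_missing` helper can be reproduced as plain user code. The following is a hedged, self-contained sketch: the synthetic pandas frame stands in for `tm.make_categorical`, the parameter values mirror the test above, and everything else is an assumption rather than part of this commit:

```python
import numpy as np
import pandas as pd
import xgboost as xgb

rng = np.random.default_rng(2022)
X = pd.DataFrame(
    {f"f{i}": pd.Categorical(rng.integers(0, 8, size=256)) for i in range(4)}
)
# Inject roughly 50% missing values, mirroring sparsity=0.5 in the test.
X = X.mask(rng.random(X.shape) < 0.5)
y = rng.normal(size=256)

Xy = xgb.DMatrix(X, y, enable_categorical=True)
# One-hot splits first, then partition-based splits, as in the test.
for max_cat_to_onehot in (np.iinfo(np.int32).max, 1):
    booster = xgb.train(
        {"tree_method": "hist", "max_cat_to_onehot": max_cat_to_onehot},
        Xy,
        num_boost_round=16,
    )
    print(booster.predict(Xy)[:4])
```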
@@ -226,7 +263,7 @@ class TestTreeMethod:
         predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
         parameters = {"tree_method": tree_method, "predictor": predictor}
         # Use one-hot exclusively
-        parameters["max_cat_to_onehot"] = USE_ONEHOT
+        parameters["max_cat_to_onehot"] = self.USE_ONEHOT
 
         m = xgb.DMatrix(onehot, label, enable_categorical=False)
         xgb.train(
@@ -260,7 +297,7 @@ class TestTreeMethod:
 
         by_grouping: xgb.callback.TrainingCallback.EvalsLog = {}
         # switch to partition-based splits
-        parameters["max_cat_to_onehot"] = USE_PART
+        parameters["max_cat_to_onehot"] = self.USE_PART
         parameters["reg_lambda"] = 0
         m = xgb.DMatrix(cat, label, enable_categorical=True)
         xgb.train(
@@ -287,27 +324,6 @@ class TestTreeMethod:
         )
         assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping
 
-        # test with missing values
-        cat, label = tm.make_categorical(
-            n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
-        )
-        Xy = xgb.DMatrix(cat, label, enable_categorical=True)
-        evals_result = {}
-        # Test with onehot splits
-        parameters["max_cat_to_onehot"] = USE_ONEHOT
-        booster = xgb.train(
-            parameters,
-            Xy,
-            num_boost_round=16,
-            evals=[(Xy, "Train")],
-            evals_result=evals_result
-        )
-        assert tm.non_increasing(evals_result["Train"]["rmse"])
-        y_predt = booster.predict(Xy)
-
-        rmse = tm.root_mean_square(label, y_predt)
-        np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])
-
     @given(strategies.integers(10, 400), strategies.integers(3, 8),
            strategies.integers(1, 2), strategies.integers(4, 7))
     @settings(deadline=None, print_blob=True)
@@ -315,3 +331,14 @@ class TestTreeMethod:
     def test_categorical(self, rows, cols, rounds, cats):
         self.run_categorical_basic(rows, cols, rounds, cats, "approx")
         self.run_categorical_basic(rows, cols, rounds, cats, "hist")
+
+    @given(
+        strategies.integers(10, 400),
+        strategies.integers(3, 8),
+        strategies.integers(4, 7)
+    )
+    @settings(deadline=None, print_blob=True)
+    @pytest.mark.skipif(**tm.no_pandas())
+    def test_categorical_missing(self, rows, cols, cats):
+        self.run_categorical_missing(rows, cols, cats, "approx")
+        self.run_categorical_missing(rows, cols, cats, "hist")