Handle missing categorical value in CPU evaluator. (#7948)

Jiaming Yuan 2022-05-27 14:15:47 +08:00 committed by GitHub
parent 2070afea02
commit bde4f25794
7 changed files with 181 additions and 154 deletions


@@ -119,13 +119,90 @@ class HistEvaluator {
p_best->Update(best);
}
/**
* \brief Enumerate with partition-based splits.
*
* The implementation differs from LightGBM in two ways. Firstly, we don't have a
* pseudo-category for the missing value; instead we make 2 complete scans over the
* histogram. Secondly, both scan directions generate splits in the same order. The
* following table depicts the scan process; square brackets mark the partition on
* which the gradient of the missing values resides:
*
* | Forward | Backward |
* |----------+----------|
* | [BCDE] A | E [ABCD] |
* | [CDE] AB | DE [ABC] |
* | [DE] ABC | CDE [AB] |
* | [E] ABCD | BCDE [A] |
*/
template <int d_step>
void EnumeratePart(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
common::GHistRow const &hist, bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) {
static_assert(d_step == +1 || d_step == -1, "Invalid step.");
auto const &cut_ptr = cut.Ptrs();
auto const &parent = snode_[nidx];
bst_bin_t n_bins{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
// statistics on both sides of split
GradStats left_sum;
GradStats right_sum;
// best split so far
SplitEntry best;
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
bst_bin_t ibegin, iend;
bst_bin_t f_begin = cut_ptr[fidx];
if (d_step > 0) {
ibegin = f_begin;
iend = ibegin + n_bins - 1;
} else {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = f_begin;
}
bst_bin_t best_thresh{-1};
for (bst_bin_t i = ibegin; i != iend; i += d_step) {
auto j = i - f_begin; // index local to current feature
if (d_step == 1) {
right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
left_sum.SetSubstract(parent.stats, right_sum); // missing on left
} else {
left_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
right_sum.SetSubstract(parent.stats, left_sum); // missing on right
}
if (IsValid(left_sum, right_sum)) {
auto loss_chg =
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
parent.root_gain;
// We don't have a numeric split point; NaN here is a dummy split value.
if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
left_sum, right_sum)) {
best_thresh = i;
}
}
}
if (best_thresh != -1) {
auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
best.cat_bits = decltype(best.cat_bits)(n, 0);
common::CatBitField cat_bits{best.cat_bits};
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : best_thresh - iend;
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
[&](size_t c) { cat_bits.Set(c); });
}
p_best->Update(best);
}
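To make the scan table above concrete, here is a minimal Python sketch of the two passes. It is an illustration only, not the C++ implementation: toy per-category gradient sums stand in for the histogram, and the category list is assumed to be pre-sorted the way `sorted_idx` orders `f_hist`.

```python
# Toy per-category gradient sums; in the real code these come from f_hist,
# indexed through sorted_idx.  `missing` is the gradient mass of rows whose
# value for this feature is absent.
cats = ["A", "B", "C", "D", "E"]
grad = {"A": 1.0, "B": 2.0, "C": 3.0, "D": 4.0, "E": 5.0}
missing = 0.5
parent = sum(grad.values()) + missing  # node total, missing included

def scan(d_step):
    """Yield (scanned prefix, left sum, right sum) for one direction."""
    order = cats if d_step == 1 else list(reversed(cats))
    acc = 0.0
    for i in range(len(order) - 1):  # n_bins - 1 candidate partitions
        acc += grad[order[i]]
        if d_step == 1:
            right = acc            # scanned categories accumulate right
            left = parent - right  # complement; missing resides left
        else:
            left = acc             # scanned categories accumulate left
            right = parent - left  # complement; missing resides right
        yield order[: i + 1], left, right

for prefix, left, right in scan(+1):
    print(f"forward : [left+missing]={left:4.1f}  right={right:4.1f}  scanned={prefix}")
for prefix, left, right in scan(-1):
    print(f"backward: left={left:4.1f}  [right+missing]={right:4.1f}  scanned={prefix}")
```

Both passes walk the same sorted order, so they propose the same partitions and differ only in which side implicitly absorbs the missing-value gradient; the winning prefix is what `EnumeratePart` records in `best.cat_bits`.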
// Enumerate/Scan the split values of a specific feature.
// Returns the sum of gradients corresponding to the data points that contain
// a non-missing value for the particular feature fid.
template <int d_step, SplitType split_type>
template <int d_step>
GradStats EnumerateSplit(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
const common::GHistRow &hist, bst_feature_t fidx,
bst_node_t nidx,
const common::GHistRow &hist, bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) const {
static_assert(d_step == +1 || d_step == -1, "Invalid step.");
@@ -134,8 +211,6 @@ class HistEvaluator {
const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
const std::vector<bst_float> &cut_val = cut.Values();
auto const &parent = snode_[nidx];
int32_t n_bins{static_cast<int32_t>(cut_ptr.at(fidx + 1) - cut_ptr[fidx])};
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
// statistics on both sides of split
GradStats left_sum;
@@ -144,50 +219,28 @@ class HistEvaluator {
SplitEntry best;
// bin boundaries
CHECK_LE(cut_ptr[fidx], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
CHECK_LE(cut_ptr[fidx + 1], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
CHECK_LE(cut_ptr[fidx], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
CHECK_LE(cut_ptr[fidx + 1], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
// imin: index (offset) of the minimum value for feature fid; needed for backward
// enumeration
const auto imin = static_cast<int32_t>(cut_ptr[fidx]);
const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);
// ibegin, iend: smallest/largest cut points for feature fid; use a signed type to
// allow the value -1
int32_t ibegin, iend;
bst_bin_t ibegin, iend;
if (d_step > 0) {
ibegin = static_cast<int32_t>(cut_ptr[fidx]);
iend = static_cast<int32_t>(cut_ptr.at(fidx + 1));
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
iend = static_cast<bst_bin_t>(cut_ptr.at(fidx + 1));
} else {
ibegin = static_cast<int32_t>(cut_ptr[fidx + 1]) - 1;
iend = static_cast<int32_t>(cut_ptr[fidx]) - 1;
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
}
auto calc_bin_value = [&](auto i) {
switch (split_type) {
case kNum: {
left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
right_sum.SetSubstract(parent.stats, left_sum);
break;
}
case kOneHot: {
std::terminate(); // unreachable
break;
}
case kPart: {
auto j = d_step == 1 ? (i - ibegin) : (ibegin - i);
right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
left_sum.SetSubstract(parent.stats, right_sum);
break;
}
}
};
int32_t best_thresh{-1};
for (int32_t i = ibegin; i != iend; i += d_step) {
for (bst_bin_t i = ibegin; i != iend; i += d_step) {
// start working
// try to find a split
calc_bin_value(i);
bool improved{false};
if (left_sum.GetHess() >= param_.min_child_weight &&
right_sum.GetHess() >= param_.min_child_weight) {
left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
right_sum.SetSubstract(parent.stats, left_sum);
if (IsValid(left_sum, right_sum)) {
bst_float loss_chg;
bst_float split_pt;
if (d_step > 0) {
@@ -197,66 +250,24 @@ class HistEvaluator {
GradStats{right_sum}) -
parent.root_gain);
split_pt = cut_val[i]; // not used for partition based
improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
left_sum, right_sum);
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
} else {
// backward enumeration: split at left bound of each bin
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{right_sum},
GradStats{left_sum}) -
parent.root_gain);
switch (split_type) {
case kNum: {
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
split_pt = cut_val[i - 1];
}
break;
}
case kOneHot: {
std::terminate(); // unreachable
break;
}
case kPart: {
split_pt = cut_val[i];
break;
}
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
split_pt = cut_val[i - 1];
}
improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
right_sum, left_sum);
}
if (improved) {
best_thresh = i;
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
}
}
}
if (split_type == kPart && best_thresh != -1) {
auto n = common::CatBitField::ComputeStorageSize(n_bins);
best.cat_bits.resize(n, 0);
common::CatBitField cat_bits{best.cat_bits};
if (d_step == 1) {
std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1),
[&](size_t c) { cat_bits.Set(cut_val[c + ibegin]); });
} else {
std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh),
[&](size_t c) { cat_bits.Set(cut_val[c + cut_ptr[fidx]]); });
}
}
p_best->Update(best);
switch (split_type) {
case kNum:
// Normal, accumulated to left
return left_sum;
case kOneHot:
return {};
case kPart:
// Accumulated to the right because the chosen categories go right.
return right_sum;
}
return left_sum;
}
@@ -316,14 +327,13 @@ class HistEvaluator {
evaluator.CalcWeightCat(param_, feat_hist[r]);
return ret;
});
EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
}
} else {
auto grad_stats =
EnumerateSplit<+1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
auto grad_stats = EnumerateSplit<+1>(cut, {}, histogram, fidx, nidx, evaluator, best);
if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
EnumerateSplit<-1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
EnumerateSplit<-1>(cut, {}, histogram, fidx, nidx, evaluator, best);
}
}
}
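The numerical path above runs the backward enumeration only when the feature actually has missing values. A hedged sketch of that check with illustrative names (the real code compares the `GradStats` returned by `EnumerateSplit` against the parent node inside `SplitContainsMissingValues`):

```python
# EnumerateSplit<+1> accumulates only rows with a non-missing value for the
# feature and returns that sum.  If it equals the parent's total, no row is
# missing, the backward scan would propose identical splits, and it is skipped.
def split_contains_missing_values(enumerated, parent):
    # Toy (grad, hess) tuples stand in for GradStats.
    return enumerated != parent

parent_stats = (10.0, 32.0)  # node total, missing included
forward_sum = (8.5, 30.0)    # accumulated over non-missing rows only
if split_contains_missing_values(forward_sum, parent_stats):
    print("feature has missing values: also run the backward enumeration")
```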


@@ -50,12 +50,11 @@ struct CPUExpandEntry {
}
friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) {
os << "ExpandEntry: \n";
os << "ExpandEntry:\n";
os << "nidx: " << e.nid << "\n";
os << "depth: " << e.depth << "\n";
os << "loss: " << e.split.loss_chg << "\n";
os << "left_sum: " << e.split.left_sum << "\n";
os << "right_sum: " << e.split.right_sum << "\n";
os << "split:\n" << e.split << std::endl;
return os;
}
};


@@ -367,12 +367,14 @@ struct SplitEntryContainer {
SplitEntryContainer() = default;
friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) {
os << "loss_chg: " << s.loss_chg << ", "
<< "split index: " << s.SplitIndex() << ", "
<< "split value: " << s.split_value << ", "
<< "left_sum: " << s.left_sum << ", "
<< "right_sum: " << s.right_sum;
friend std::ostream &operator<<(std::ostream &os, SplitEntryContainer const &s) {
os << "loss_chg: " << s.loss_chg << "\n"
<< "dft_left: " << s.DefaultLeft() << "\n"
<< "split_index: " << s.SplitIndex() << "\n"
<< "split_value: " << s.split_value << "\n"
<< "is_cat: " << s.is_cat << "\n"
<< "left_sum: " << s.left_sum << "\n"
<< "right_sum: " << s.right_sum << std::endl;
return os;
}
/*!\return feature index to split on */
@@ -446,30 +448,6 @@ struct SplitEntryContainer {
}
}
/*!
* \brief Update with partition-based categorical split.
*
* \return Whether the proposed split is better and can replace current split.
*/
bool Update(float new_loss_chg, bst_feature_t split_index, common::KCatBitField cats,
bool default_left, GradientT const &left_sum, GradientT const &right_sum) {
if (this->NeedReplace(new_loss_chg, split_index)) {
this->loss_chg = new_loss_chg;
if (default_left) {
split_index |= (1U << 31);
}
this->sindex = split_index;
cat_bits.resize(cats.Bits().size());
std::copy(cats.Bits().begin(), cats.Bits().end(), cat_bits.begin());
this->is_cat = true;
this->left_sum = left_sum;
this->right_sum = right_sum;
return true;
} else {
return false;
}
}
/*! \brief same as update, used by AllReduce*/
inline static void Reduce(SplitEntryContainer &dst, // NOLINT(*)
const SplitEntryContainer &src) { // NOLINT(*)
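The removed `Update` overload above copied a pre-built bitfield; after this change the evaluator fills `best.cat_bits` itself. As a reminder of what the bits mean, here is a hedged sketch with illustrative names (not the C++ `CatBitField` API): the chosen categories are recorded as set bits, and, per the "chosen cats go to right" comment in the evaluator, a row is routed right exactly when the bit for its category is set.

```python
def make_cat_bits(chosen, n_cats):
    """Encode chosen categories as a bitmask (toy stand-in for CatBitField)."""
    assert all(0 <= c < n_cats for c in chosen)
    bits = 0
    for c in chosen:
        bits |= 1 << c
    return bits

def route(category, cat_bits, default_left, is_missing=False):
    """Rows with a set bit go right; missing values follow the default direction."""
    if is_missing:
        return "left" if default_left else "right"
    return "right" if (cat_bits >> category) & 1 else "left"

bits = make_cat_bits({0, 2}, n_cats=5)  # categories 0 and 2 go right
assert route(2, bits, default_left=True) == "right"
assert route(1, bits, default_left=True) == "left"
assert route(0, bits, default_left=True, is_missing=True) == "left"
```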


@@ -147,9 +147,9 @@ auto CompareOneHotAndPartition(bool onehot) {
auto dmat =
RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
int32_t n_threads = 16;
auto sampler = std::make_shared<common::ColumnSampler>();
auto evaluator = HistEvaluator<CPUExpandEntry>{param, dmat->Info(), n_threads, sampler};
auto evaluator =
HistEvaluator<CPUExpandEntry>{param, dmat->Info(), common::OmpGetNumThreads(0), sampler};
std::vector<CPUExpandEntry> entries(1);
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {


@@ -2,11 +2,14 @@
* Copyright 2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <algorithm> // next_permutation
#include <numeric> // iota
#include "../../../src/tree/hist/evaluate_splits.h"
#include "../../../src/common/hist_util.h" // HistogramCuts,HistCollection
#include "../../../src/tree/param.h" // TrainParam
#include "../../../src/tree/split_evaluator.h"
#include "../helpers.h"
namespace xgboost {


@@ -77,6 +77,16 @@ class TestGPUUpdaters:
def test_categorical(self, rows, cols, rounds, cats):
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
@given(
strategies.integers(10, 400),
strategies.integers(3, 8),
strategies.integers(4, 7)
)
@settings(deadline=None, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical_missing(self, rows, cols, cats):
self.cputest.run_categorical_missing(rows, cols, cats, "gpu_hist")
def test_max_cat(self) -> None:
self.cputest.run_max_cat("gpu_hist")


@@ -1,5 +1,6 @@
from random import choice
from string import ascii_lowercase
from typing import Dict, Any
import testing as tm
import pytest
import xgboost as xgb
@@ -38,6 +39,9 @@ def train_result(param, dmat, num_rounds):
class TestTreeMethod:
USE_ONEHOT = np.iinfo(np.int32).max
USE_PART = 1
@given(exact_parameter_strategy, strategies.integers(1, 20),
tm.dataset_strategy)
@settings(deadline=None, print_blob=True)
@@ -213,10 +217,43 @@ class TestTreeMethod:
def test_max_cat(self, tree_method) -> None:
self.run_max_cat(tree_method)
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
USE_ONEHOT = np.iinfo(np.int32).max
USE_PART = 1
def run_categorical_missing(
self, rows: int, cols: int, cats: int, tree_method: str
) -> None:
parameters: Dict[str, Any] = {"tree_method": tree_method}
cat, label = tm.make_categorical(
n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
)
Xy = xgb.DMatrix(cat, label, enable_categorical=True)
def run(max_cat_to_onehot: int):
# Force the split strategy via max_cat_to_onehot.
parameters["max_cat_to_onehot"] = max_cat_to_onehot
evals_result: Dict[str, Dict] = {}
booster = xgb.train(
parameters,
Xy,
num_boost_round=16,
evals=[(Xy, "Train")],
evals_result=evals_result
)
assert tm.non_increasing(evals_result["Train"]["rmse"])
y_predt = booster.predict(Xy)
rmse = tm.root_mean_square(label, y_predt)
np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])
# Test with OHE split
run(self.USE_ONEHOT)
if tree_method == "gpu_hist": # fixme: Test with GPU.
return
# Test with partition-based split
run(self.USE_PART)
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False)
@@ -226,7 +263,7 @@ class TestTreeMethod:
predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
parameters = {"tree_method": tree_method, "predictor": predictor}
# Use one-hot exclusively
parameters["max_cat_to_onehot"] = USE_ONEHOT
parameters["max_cat_to_onehot"] = self.USE_ONEHOT
m = xgb.DMatrix(onehot, label, enable_categorical=False)
xgb.train(
@@ -260,7 +297,7 @@ class TestTreeMethod:
by_grouping: xgb.callback.TrainingCallback.EvalsLog = {}
# switch to partition-based splits
parameters["max_cat_to_onehot"] = USE_PART
parameters["max_cat_to_onehot"] = self.USE_PART
parameters["reg_lambda"] = 0
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
@@ -287,27 +324,6 @@ class TestTreeMethod:
)
assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping
# test with missing values
cat, label = tm.make_categorical(
n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
)
Xy = xgb.DMatrix(cat, label, enable_categorical=True)
evals_result = {}
# Test with onehot splits
parameters["max_cat_to_onehot"] = USE_ONEHOT
booster = xgb.train(
parameters,
Xy,
num_boost_round=16,
evals=[(Xy, "Train")],
evals_result=evals_result
)
assert tm.non_increasing(evals_result["Train"]["rmse"])
y_predt = booster.predict(Xy)
rmse = tm.root_mean_square(label, y_predt)
np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])
@given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None, print_blob=True)
@@ -315,3 +331,14 @@ class TestTreeMethod:
def test_categorical(self, rows, cols, rounds, cats):
self.run_categorical_basic(rows, cols, rounds, cats, "approx")
self.run_categorical_basic(rows, cols, rounds, cats, "hist")
@given(
strategies.integers(10, 400),
strategies.integers(3, 8),
strategies.integers(4, 7)
)
@settings(deadline=None, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical_missing(self, rows, cols, cats):
self.run_categorical_missing(rows, cols, cats, "approx")
self.run_categorical_missing(rows, cols, cats, "hist")
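For reference, a hedged end-to-end sketch of what `run_categorical_missing` exercises, with hypothetical data: a categorical column containing NaNs is trained under both split strategies, using `max_cat_to_onehot` at int32 max to force one-hot splits and at 1 to force partition-based splits.

```python
import numpy as np
import pandas as pd
import xgboost as xgb

rng = np.random.default_rng(0)
df = pd.DataFrame({"cat": pd.Categorical(rng.integers(0, 8, size=256))})
df.loc[rng.random(256) < 0.5, "cat"] = np.nan  # roughly 50% missing values
y = rng.normal(size=256)

Xy = xgb.DMatrix(df, y, enable_categorical=True)
for max_cat in (np.iinfo(np.int32).max, 1):    # one-hot, then partition-based
    booster = xgb.train(
        {"tree_method": "hist", "max_cat_to_onehot": max_cat},
        Xy,
        num_boost_round=16,
    )
```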