From bde4f257941b3187fadfce76e945fdbf7cf3b161 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Fri, 27 May 2022 14:15:47 +0800
Subject: [PATCH] Handle missing categorical value in CPU evaluator. (#7948)

---
 src/tree/hist/evaluate_splits.h             | 194 ++++++++++----------
 src/tree/hist/expand_entry.h                |   5 +-
 src/tree/param.h                            |  38 +---
 tests/cpp/tree/hist/test_evaluate_splits.cc |   4 +-
 tests/cpp/tree/test_evaluate_splits.h       |   5 +-
 tests/python-gpu/test_gpu_updaters.py       |  10 +
 tests/python/test_updaters.py               |  79 +++++---
 7 files changed, 181 insertions(+), 154 deletions(-)

diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h
index 7fbd27d56..312012682 100644
--- a/src/tree/hist/evaluate_splits.h
+++ b/src/tree/hist/evaluate_splits.h
@@ -119,13 +119,90 @@ class HistEvaluator {
     p_best->Update(best);
   }
 
+  /**
+   * \brief Enumerate with partition-based splits.
+   *
+   * The implementation is different from LightGBM. Firstly, we don't have a
+   * pseudo-category for missing values; instead we make 2 complete scans over the
+   * histogram. Secondly, both scan directions generate splits in the same
+   * order. The following table depicts the scan process; square brackets mean the
+   * gradient of missing values resides in that partition:
+   *
+   * | Forward  | Backward |
+   * |----------+----------|
+   * | [BCDE] A | E [ABCD] |
+   * | [CDE] AB | DE [ABC] |
+   * | [DE] ABC | CDE [AB] |
+   * | [E] ABCD | BCDE [A] |
+   */
+  template <int d_step>
+  void EnumeratePart(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
+                     common::GHistRow const &hist, bst_feature_t fidx, bst_node_t nidx,
+                     TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
+                     SplitEntry *p_best) {
+    static_assert(d_step == +1 || d_step == -1, "Invalid step.");
+
+    auto const &cut_ptr = cut.Ptrs();
+    auto const &parent = snode_[nidx];
+    bst_bin_t n_bins{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
+
+    // statistics on both sides of split
+    GradStats left_sum;
+    GradStats right_sum;
+    // best split so far
+    SplitEntry best;
+
+    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
+    bst_bin_t ibegin, iend;
+    bst_bin_t f_begin = cut_ptr[fidx];
+    if (d_step > 0) {
+      ibegin = f_begin;
+      iend = ibegin + n_bins - 1;
+    } else {
+      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
+      iend = f_begin;
+    }
+
+    bst_bin_t best_thresh{-1};
+    for (bst_bin_t i = ibegin; i != iend; i += d_step) {
+      auto j = i - f_begin;  // index local to current feature
+      if (d_step == 1) {
+        right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
+        left_sum.SetSubstract(parent.stats, right_sum);  // missing on left
+      } else {
+        left_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
+        right_sum.SetSubstract(parent.stats, left_sum);  // missing on right
+      }
+      if (IsValid(left_sum, right_sum)) {
+        auto loss_chg =
+            evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
+            parent.root_gain;
+        // We don't have a numeric split point; NaN here is a dummy split.
+        if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
+                        left_sum, right_sum)) {
+          best_thresh = i;
+        }
+      }
+    }
+
+    if (best_thresh != -1) {
+      auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
+      best.cat_bits = decltype(best.cat_bits)(n, 0);
+      common::CatBitField cat_bits{best.cat_bits};
+      bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : best_thresh - iend;
+      std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
+                    [&](size_t c) { cat_bits.Set(c); });
+    }
+
+    p_best->Update(best);
+  }
+
   // Enumerate/Scan the split values of specific feature
   // Returns the sum of gradients corresponding to the data points that contain
   // a non-missing value for the particular feature fid.
-  template <int d_step, SplitType split_type>
+  template <int d_step>
   GradStats EnumerateSplit(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
-                           const common::GHistRow &hist, bst_feature_t fidx,
-                           bst_node_t nidx,
+                           const common::GHistRow &hist, bst_feature_t fidx, bst_node_t nidx,
                            TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                            SplitEntry *p_best) const {
     static_assert(d_step == +1 || d_step == -1, "Invalid step.");
@@ -134,8 +211,6 @@ class HistEvaluator {
     const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
     const std::vector<float> &cut_val = cut.Values();
     auto const &parent = snode_[nidx];
-    int32_t n_bins{static_cast<int32_t>(cut_ptr.at(fidx + 1) - cut_ptr[fidx])};
-    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
 
     // statistics on both sides of split
     GradStats left_sum;
@@ -144,50 +219,28 @@ class HistEvaluator {
     SplitEntry best;
 
     // bin boundaries
-    CHECK_LE(cut_ptr[fidx], static_cast<size_t>(std::numeric_limits<int32_t>::max()));
-    CHECK_LE(cut_ptr[fidx + 1], static_cast<size_t>(std::numeric_limits<int32_t>::max()));
+    CHECK_LE(cut_ptr[fidx], static_cast<size_t>(std::numeric_limits<bst_bin_t>::max()));
+    CHECK_LE(cut_ptr[fidx + 1], static_cast<size_t>(std::numeric_limits<bst_bin_t>::max()));
     // imin: index (offset) of the minimum value for feature fid need this for backward
     //       enumeration
-    const auto imin = static_cast<int32_t>(cut_ptr[fidx]);
+    const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);
     // ibegin, iend: smallest/largest cut points for feature fid use int to allow for
     // value -1
-    int32_t ibegin, iend;
+    bst_bin_t ibegin, iend;
     if (d_step > 0) {
-      ibegin = static_cast<int32_t>(cut_ptr[fidx]);
-      iend = static_cast<int32_t>(cut_ptr.at(fidx + 1));
+      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
+      iend = static_cast<bst_bin_t>(cut_ptr.at(fidx + 1));
     } else {
-      ibegin = static_cast<int32_t>(cut_ptr[fidx + 1]) - 1;
-      iend = static_cast<int32_t>(cut_ptr[fidx]) - 1;
+      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
+      iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
     }
 
-    auto calc_bin_value = [&](auto i) {
-      switch (split_type) {
-        case kNum: {
-          left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
-          right_sum.SetSubstract(parent.stats, left_sum);
-          break;
-        }
-        case kOneHot: {
-          std::terminate();  // unreachable
-          break;
-        }
-        case kPart: {
-          auto j = d_step == 1 ? (i - ibegin) : (ibegin - i);
-          right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
-          left_sum.SetSubstract(parent.stats, right_sum);
-          break;
-        }
-      }
-    };
-
-    int32_t best_thresh{-1};
-    for (int32_t i = ibegin; i != iend; i += d_step) {
+    for (bst_bin_t i = ibegin; i != iend; i += d_step) {
       // start working
       // try to find a split
-      calc_bin_value(i);
-      bool improved{false};
-      if (left_sum.GetHess() >= param_.min_child_weight &&
-          right_sum.GetHess() >= param_.min_child_weight) {
+      left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
+      right_sum.SetSubstract(parent.stats, left_sum);
+      if (IsValid(left_sum, right_sum)) {
         bst_float loss_chg;
         bst_float split_pt;
         if (d_step > 0) {
@@ -197,66 +250,24 @@ class HistEvaluator {
                                              GradStats{right_sum}) -
                            parent.root_gain);
           split_pt = cut_val[i];  // not used for partition based
-          improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
-                                 left_sum, right_sum);
+          best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
         } else {
           // backward enumeration: split at left bound of each bin
           loss_chg = static_cast<bst_float>(evaluator.CalcSplitGain(param_, nidx, fidx,
                                                                     GradStats{right_sum},
                                                                     GradStats{left_sum}) -
                                             parent.root_gain);
-          switch (split_type) {
-            case kNum: {
-              if (i == imin) {
-                split_pt = cut.MinValues()[fidx];
-              } else {
-                split_pt = cut_val[i - 1];
-              }
-              break;
-            }
-            case kOneHot: {
-              std::terminate();  // unreachable
-              break;
-            }
-            case kPart: {
-              split_pt = cut_val[i];
-              break;
-            }
+          if (i == imin) {
+            split_pt = cut.MinValues()[fidx];
+          } else {
+            split_pt = cut_val[i - 1];
           }
-          improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
-                                 right_sum, left_sum);
-        }
-        if (improved) {
-          best_thresh = i;
+          best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
         }
       }
     }
-    if (split_type == kPart && best_thresh != -1) {
-      auto n = common::CatBitField::ComputeStorageSize(n_bins);
-      best.cat_bits.resize(n, 0);
-      common::CatBitField cat_bits{best.cat_bits};
-
-      if (d_step == 1) {
-        std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1),
-                      [&](size_t c) { cat_bits.Set(cut_val[c + ibegin]); });
-      } else {
-        std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh),
-                      [&](size_t c) { cat_bits.Set(cut_val[c + cut_ptr[fidx]]); });
-      }
-    }
     p_best->Update(best);
-
-    switch (split_type) {
-      case kNum:
-        // Normal, accumulated to left
-        return left_sum;
-      case kOneHot:
-        return {};
-      case kPart:
-        // Accumulated to right due to chosen cats go to right.
-        return right_sum;
-    }
     return left_sum;
   }
@@ -316,14 +327,13 @@ class HistEvaluator {
               evaluator.CalcWeightCat(param_, feat_hist[r]);
           return ret;
         });
-        EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
-        EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
+        EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
+        EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
       }
     } else {
-      auto grad_stats =
-          EnumerateSplit<+1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
+      auto grad_stats = EnumerateSplit<+1>(cut, {}, histogram, fidx, nidx, evaluator, best);
       if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
-        EnumerateSplit<-1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
+        EnumerateSplit<-1>(cut, {}, histogram, fidx, nidx, evaluator, best);
       }
     }
   }
diff --git a/src/tree/hist/expand_entry.h b/src/tree/hist/expand_entry.h
index d0edfbd37..885a109bf 100644
--- a/src/tree/hist/expand_entry.h
+++ b/src/tree/hist/expand_entry.h
@@ -50,12 +50,11 @@ struct CPUExpandEntry {
   }
 
   friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) {
-    os << "ExpandEntry: \n";
+    os << "ExpandEntry:\n";
     os << "nidx: " << e.nid << "\n";
     os << "depth: " << e.depth << "\n";
     os << "loss: " << e.split.loss_chg << "\n";
-    os << "left_sum: " << e.split.left_sum << "\n";
-    os << "right_sum: " << e.split.right_sum << "\n";
+    os << "split:\n" << e.split << std::endl;
     return os;
   }
 };
diff --git a/src/tree/param.h b/src/tree/param.h
index 7ed796a1e..ab9e23098 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -367,12 +367,14 @@ struct SplitEntryContainer {
 
   SplitEntryContainer() = default;
 
-  friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) {
-    os << "loss_chg: " << s.loss_chg << ", "
-       << "split index: " << s.SplitIndex() << ", "
-       << "split value: " << s.split_value << ", "
-       << "left_sum: " << s.left_sum << ", "
-       << "right_sum: " << s.right_sum;
+  friend std::ostream &operator<<(std::ostream &os, SplitEntryContainer const &s) {
+    os << "loss_chg: " << s.loss_chg << "\n"
+       << "dft_left: " << s.DefaultLeft() << "\n"
+       << "split_index: " << s.SplitIndex() << "\n"
+       << "split_value: " << s.split_value << "\n"
+       << "is_cat: " << s.is_cat << "\n"
+       << "left_sum: " << s.left_sum << "\n"
+       << "right_sum: " << s.right_sum << std::endl;
     return os;
   }
   /*!\return feature index to split on */
@@ -446,30 +448,6 @@ struct SplitEntryContainer {
     }
   }
 
-  /*!
-   * \brief Update with partition based categorical split.
-   *
-   * \return Whether the proposed split is better and can replace current split.
-   */
-  bool Update(float new_loss_chg, bst_feature_t split_index, common::KCatBitField cats,
-              bool default_left, GradientT const &left_sum, GradientT const &right_sum) {
-    if (this->NeedReplace(new_loss_chg, split_index)) {
-      this->loss_chg = new_loss_chg;
-      if (default_left) {
-        split_index |= (1U << 31);
-      }
-      this->sindex = split_index;
-      cat_bits.resize(cats.Bits().size());
-      std::copy(cats.Bits().begin(), cats.Bits().end(), cat_bits.begin());
-      this->is_cat = true;
-      this->left_sum = left_sum;
-      this->right_sum = right_sum;
-      return true;
-    } else {
-      return false;
-    }
-  }
-
   /*!\brief same as update, used by AllReduce*/
   inline static void Reduce(SplitEntryContainer &dst,          // NOLINT(*)
                             const SplitEntryContainer &src) {  // NOLINT(*)
diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc
index 493535aab..0dede2706 100644
--- a/tests/cpp/tree/hist/test_evaluate_splits.cc
+++ b/tests/cpp/tree/hist/test_evaluate_splits.cc
@@ -147,9 +147,9 @@ auto CompareOneHotAndPartition(bool onehot) {
   auto dmat =
       RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
-  int32_t n_threads = 16;
   auto sampler = std::make_shared<common::ColumnSampler>();
-  auto evaluator = HistEvaluator<CPUExpandEntry>{param, dmat->Info(), n_threads, sampler};
+  auto evaluator =
+      HistEvaluator<CPUExpandEntry>{param, dmat->Info(), common::OmpGetNumThreads(0), sampler};
 
   std::vector<CPUExpandEntry> entries(1);
   for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
diff --git a/tests/cpp/tree/test_evaluate_splits.h b/tests/cpp/tree/test_evaluate_splits.h
index c8e0f577e..bbd8b98eb 100644
--- a/tests/cpp/tree/test_evaluate_splits.h
+++ b/tests/cpp/tree/test_evaluate_splits.h
@@ -2,11 +2,14 @@
  * Copyright 2022 by XGBoost Contributors
  */
 #include <gtest/gtest.h>
+#include <xgboost/base.h>
 
 #include <algorithm>  // next_permutation
 #include <numeric>    // iota
 
-#include "../../../src/tree/hist/evaluate_splits.h"
+#include "../../../src/common/hist_util.h"  // HistogramCuts,HistCollection
+#include "../../../src/tree/param.h"        // TrainParam
+#include "../../../src/tree/split_evaluator.h"
 #include "../helpers.h"
 
 namespace xgboost {
diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index 860b45929..fd27e3771 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -77,6 +77,16 @@ class TestGPUUpdaters:
     def test_categorical(self, rows, cols, rounds, cats):
         self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
 
+    @given(
+        strategies.integers(10, 400),
+        strategies.integers(3, 8),
+        strategies.integers(4, 7)
+    )
+    @settings(deadline=None, print_blob=True)
+    @pytest.mark.skipif(**tm.no_pandas())
+    def test_categorical_missing(self, rows, cols, cats):
+        self.cputest.run_categorical_missing(rows, cols, cats, "gpu_hist")
+
     def test_max_cat(self) -> None:
         self.cputest.run_max_cat("gpu_hist")
 
diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py
index 4e73bab31..889a7c77f 100644
--- a/tests/python/test_updaters.py
+++ b/tests/python/test_updaters.py
@@ -1,5 +1,6 @@
 from random import choice
 from string import ascii_lowercase
+from typing import Dict, Any
 import testing as tm
 import pytest
 import xgboost as xgb
@@ -38,6 +39,9 @@ def train_result(param, dmat, num_rounds):
 
 
 class TestTreeMethod:
+    USE_ONEHOT = np.iinfo(np.int32).max
+    USE_PART = 1
+
     @given(exact_parameter_strategy, strategies.integers(1, 20),
            tm.dataset_strategy)
     @settings(deadline=None, print_blob=True)
@@ -213,10 +217,43 @@ class TestTreeMethod:
     def test_max_cat(self, tree_method) -> None:
         self.run_max_cat(tree_method)
 
-    def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
-        USE_ONEHOT = np.iinfo(np.int32).max
-        USE_PART = 1
+    def run_categorical_missing(
+        self, rows: int, cols: int, cats: int, tree_method: str
+    ) -> None:
+        parameters: Dict[str, Any] = {"tree_method": tree_method}
+        cat, label = tm.make_categorical(
+            n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
+        )
+        Xy = xgb.DMatrix(cat, label, enable_categorical=True)
+
+        def run(max_cat_to_onehot: int):
+            # Switch between onehot and partition-based splits.
+            parameters["max_cat_to_onehot"] = max_cat_to_onehot
+
+            evals_result: Dict[str, Dict] = {}
+            booster = xgb.train(
+                parameters,
+                Xy,
+                num_boost_round=16,
+                evals=[(Xy, "Train")],
+                evals_result=evals_result
+            )
+            assert tm.non_increasing(evals_result["Train"]["rmse"])
+            y_predt = booster.predict(Xy)
+
+            rmse = tm.root_mean_square(label, y_predt)
+            np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])
+
+        # Test with OHE split
+        run(self.USE_ONEHOT)
+
+        if tree_method == "gpu_hist":  # fixme: Test with GPU.
+            return
+
+        # Test with partition-based split
+        run(self.USE_PART)
+
+    def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
         onehot, label = tm.make_categorical(rows, cols, cats, True)
         cat, _ = tm.make_categorical(rows, cols, cats, False)
 
@@ -226,7 +263,7 @@ class TestTreeMethod:
         predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
         parameters = {"tree_method": tree_method, "predictor": predictor}
         # Use one-hot exclusively
-        parameters["max_cat_to_onehot"] = USE_ONEHOT
+        parameters["max_cat_to_onehot"] = self.USE_ONEHOT
 
         m = xgb.DMatrix(onehot, label, enable_categorical=False)
         xgb.train(
@@ -260,7 +297,7 @@ class TestTreeMethod:
         by_grouping: xgb.callback.TrainingCallback.EvalsLog = {}
         # switch to partition-based splits
-        parameters["max_cat_to_onehot"] = USE_PART
+        parameters["max_cat_to_onehot"] = self.USE_PART
         parameters["reg_lambda"] = 0
         m = xgb.DMatrix(cat, label, enable_categorical=True)
         xgb.train(
@@ -287,27 +324,6 @@ class TestTreeMethod:
         )
         assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping
 
-        # test with missing values
-        cat, label = tm.make_categorical(
-            n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
-        )
-        Xy = xgb.DMatrix(cat, label, enable_categorical=True)
-        evals_result = {}
-        # Test with onehot splits
-        parameters["max_cat_to_onehot"] = USE_ONEHOT
-        booster = xgb.train(
-            parameters,
-            Xy,
-            num_boost_round=16,
-            evals=[(Xy, "Train")],
-            evals_result=evals_result
-        )
-        assert tm.non_increasing(evals_result["Train"]["rmse"])
-        y_predt = booster.predict(Xy)
-
-        rmse = tm.root_mean_square(label, y_predt)
-        np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])
-
     @given(strategies.integers(10, 400), strategies.integers(3, 8),
            strategies.integers(1, 2), strategies.integers(4, 7))
     @settings(deadline=None, print_blob=True)
     @pytest.mark.skipif(**tm.no_pandas())
     def test_categorical(self, rows, cols, rounds, cats):
         self.run_categorical_basic(rows, cols, rounds, cats, "approx")
         self.run_categorical_basic(rows, cols, rounds, cats, "hist")
+
+    @given(
+        strategies.integers(10, 400),
+        strategies.integers(3, 8),
+        strategies.integers(4, 7)
+    )
+    @settings(deadline=None, print_blob=True)
+    @pytest.mark.skipif(**tm.no_pandas())
+    def test_categorical_missing(self, rows, cols, cats):
+        self.run_categorical_missing(rows, cols, cats, "approx")
+        self.run_categorical_missing(rows, cols, cats, "hist")
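
The two-scan idea behind EnumeratePart() is easiest to check on a toy example. The following
self-contained sketch is not XGBoost code: the histogram values, the simplified G^2/(H + lambda)
gain, and all names are illustrative assumptions. The invariant it demonstrates is the one the
patch relies on: the scanned categories accumulate on one side, and the complement (parent minus
accumulated) implicitly carries the gradient of the missing values.

// Toy model of the two-scan, partition-based split search in EnumeratePart().
// Everything here (names, histogram values, gain formula) is illustrative only.
#include <iostream>
#include <vector>

struct GradPair {
  double grad{0.0};
  double hess{0.0};
};

// Simplified child gain: G^2 / (H + lambda).
double CalcGain(GradPair const &s, double lambda = 1.0) {
  return (s.grad * s.grad) / (s.hess + lambda);
}

int main() {
  // One categorical feature with 5 categories (A..E); bins are already
  // ordered by mean gradient, as the evaluator arranges them via sorted_idx.
  std::vector<GradPair> hist{{-4.0, 2.0}, {-1.5, 2.0}, {0.5, 2.0}, {2.0, 2.0}, {3.5, 2.0}};
  // Rows with a missing category contribute to the parent statistics but to
  // no histogram bin.
  GradPair missing{1.0, 3.0};

  GradPair parent = missing;
  for (auto const &g : hist) {
    parent.grad += g.grad;
    parent.hess += g.hess;
  }
  double const parent_gain = CalcGain(parent);

  double best_gain = 0.0;
  int best_thresh = -1;
  bool best_missing_left = false;

  auto scan = [&](int step) {  // step = +1 forward, -1 backward
    GradPair acc;  // categories scanned so far in this direction
    int const n = static_cast<int>(hist.size());
    int const begin = step > 0 ? 0 : n - 1;
    // The last bin is excluded, mirroring iend in the patch: putting every
    // category on one side is not a split.
    int const end = step > 0 ? n - 1 : 0;
    for (int i = begin; i != end; i += step) {
      acc.grad += hist[i].grad;
      acc.hess += hist[i].hess;
      // The complement side is parent - acc, so it implicitly receives the
      // missing-value gradient, just like SetSubstract() in the patch.
      GradPair rest{parent.grad - acc.grad, parent.hess - acc.hess};
      double gain = CalcGain(acc) + CalcGain(rest) - parent_gain;
      if (gain > best_gain) {
        best_gain = gain;
        best_thresh = i;
        // The forward scan accumulates the right partition, so missing
        // values stay on the left ([BCDE] A in the table); backward is the
        // mirror image.
        best_missing_left = step > 0;
      }
    }
  };
  scan(+1);
  scan(-1);

  std::cout << "best gain: " << best_gain << ", threshold: " << best_thresh
            << ", missing goes " << (best_missing_left ? "left" : "right") << "\n";
  return 0;
}

Running both directions matters because a single scan can only express partitions where the
missing values sit on one fixed side; the backward scan covers the mirrored assignments, which
is why EvaluateSplits calls EnumeratePart<+1> and EnumeratePart<-1> back to back.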
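The chosen split stores no numeric threshold (best.Update receives a NaN dummy), only a bitset
over category values. In both scan directions the categories routed to the right child form a
prefix of sorted_idx: of length best_thresh - ibegin + 1 for the forward scan and
best_thresh - iend for the backward one, which is why a single prefix-marking loop serves both.
A minimal stand-in for that bookkeeping follows; MakeCatBits is a hypothetical helper, not
XGBoost's common::CatBitField API.

// Stand-in for the cat_bits bookkeeping at the end of EnumeratePart().
// MakeCatBits is hypothetical; XGBoost uses common::CatBitField instead.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<std::uint32_t> MakeCatBits(std::vector<std::size_t> const &sorted_idx,
                                       std::size_t partition) {
  // One bit per category, packed into 32-bit words; a set bit means the
  // category goes to the right child.
  std::vector<std::uint32_t> bits((sorted_idx.size() + 31) / 32, 0);
  for (std::size_t i = 0; i < partition; ++i) {
    std::size_t cat = sorted_idx[i];
    bits[cat / 32] |= (1u << (cat % 32));
  }
  return bits;
}

int main() {
  // Categories 0..4 as ordered by the gradient-based sort.
  std::vector<std::size_t> sorted_idx{3, 0, 4, 1, 2};
  // Suppose the best forward-scan threshold puts the first two sorted
  // categories on the right: {3, 0}.
  auto bits = MakeCatBits(sorted_idx, /*partition=*/2);
  std::cout << "0x" << std::hex << bits[0] << "\n";  // 0x9: bits 0 and 3 set
  return 0;
}

The backward scan accumulates the left partition, so its complement — the prefix of sorted_idx
it has not yet reached — is what goes right; that symmetry lets both directions share the same
representation.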