Add max_cat_threshold to GPU and handle missing cat values. (#8212)
This commit is contained in:
@@ -43,7 +43,7 @@ class TestPartitionBasedSplit : public ::testing::Test {
|
||||
auto &h_vals = cuts_.cut_values_.HostVector();
|
||||
h_vals.resize(n_bins_);
|
||||
std::iota(h_vals.begin(), h_vals.end(), 0.0);
|
||||
|
||||
|
||||
cuts_.min_vals_.Resize(1);
|
||||
|
||||
hist_.Init(cuts_.TotalBins());
|
||||
@@ -97,5 +97,59 @@ class TestPartitionBasedSplit : public ::testing::Test {
|
||||
} while (std::next_permutation(sorted_idx_.begin(), sorted_idx_.end()));
|
||||
}
|
||||
};
|
||||
|
||||
inline auto MakeCutsForTest(std::vector<float> values, std::vector<uint32_t> ptrs,
|
||||
std::vector<float> min_values, int32_t device) {
|
||||
common::HistogramCuts cuts;
|
||||
cuts.cut_values_.HostVector() = values;
|
||||
cuts.cut_ptrs_.HostVector() = ptrs;
|
||||
cuts.min_vals_.HostVector() = min_values;
|
||||
|
||||
if (device >= 0) {
|
||||
cuts.cut_ptrs_.SetDevice(device);
|
||||
cuts.cut_values_.SetDevice(device);
|
||||
cuts.min_vals_.SetDevice(device);
|
||||
}
|
||||
|
||||
return cuts;
|
||||
}
|
||||
|
||||
class TestCategoricalSplitWithMissing : public testing::Test {
|
||||
protected:
|
||||
common::HistogramCuts cuts_;
|
||||
// Setup gradients and parent sum with missing values.
|
||||
GradientPairPrecise parent_sum_{1.0, 6.0};
|
||||
std::vector<GradientPairPrecise> feature_histogram_{
|
||||
{0.5, 0.5}, {0.5, 0.5}, {1.0, 1.0}, {1.0, 1.0}};
|
||||
TrainParam param_;
|
||||
|
||||
void SetUp() override {
|
||||
cuts_ = MakeCutsForTest({0.0, 1.0, 2.0, 3.0}, {0, 4}, {0.0}, -1);
|
||||
auto max_cat = *std::max_element(cuts_.cut_values_.HostVector().begin(),
|
||||
cuts_.cut_values_.HostVector().end());
|
||||
cuts_.SetCategorical(true, max_cat);
|
||||
param_.UpdateAllowUnknown(
|
||||
Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}, {"max_cat_to_onehot", "1"}});
|
||||
}
|
||||
|
||||
void CheckResult(float loss_chg, bst_feature_t split_ind, float fvalue, bool is_cat,
|
||||
bool dft_left, GradientPairPrecise left_sum, GradientPairPrecise right_sum) {
|
||||
// forward
|
||||
// it: 0, gain: 0.545455
|
||||
// it: 1, gain: 1.000000
|
||||
// it: 2, gain: 2.250000
|
||||
// backward
|
||||
// it: 3, gain: 1.000000
|
||||
// it: 2, gain: 2.250000
|
||||
// it: 1, gain: 3.142857
|
||||
ASSERT_NEAR(loss_chg, 2.97619, kRtEps);
|
||||
ASSERT_TRUE(is_cat);
|
||||
ASSERT_TRUE(std::isnan(fvalue));
|
||||
ASSERT_EQ(split_ind, 0);
|
||||
ASSERT_FALSE(dft_left);
|
||||
ASSERT_EQ(left_sum.GetHess(), 2.5);
|
||||
ASSERT_EQ(right_sum.GetHess(), parent_sum_.GetHess() - left_sum.GetHess());
|
||||
}
|
||||
};
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user